/*-------------------------------------------------------------------------
 *
 * shard_rebalancer.c
 *
 * Function definitions for the shard rebalancer tool.
 *
 * Copyright (c) Citus Data, Inc.
 *
 * $Id$
 *
 *-------------------------------------------------------------------------
 */

#include <math.h>

#include "postgres.h"

#include "funcapi.h"
#include "libpq/libpq-fe.h"
#include "miscadmin.h"

#include "access/genam.h"
#include "access/htup.h"
#include "catalog/pg_proc.h"
#include "catalog/pg_type.h"
#include "commands/dbcommands.h"
#include "commands/sequence.h"
#ifdef DISABLE_OG_COMMENTS
#include "common/hashfn.h"
#else
#include "utils/hsearch.h"
#endif
#include "postmaster/postmaster.h"
#include "storage/lmgr.h"
#include "utils/builtins.h"
#include "utils/fmgroids.h"
#include "utils/guc_tables.h"
#include "utils/json.h"
#include "utils/lsyscache.h"
#include "utils/memutils.h"
#include "utils/pg_lsn.h"
#include "utils/syscache.h"
#ifdef DISABLE_OG_COMMENTS
#include "utils/varlena.h"
#endif
#include "pg_version_constants.h"

#include "distributed/argutils.h"
#include "distributed/background_jobs.h"
#include "distributed/citus_ruleutils.h"
#include "distributed/citus_safe_lib.h"
#include "distributed/commands/citus_sequence.h"
#include "distributed/commands.h"
#include "distributed/colocation_utils.h"
#include "distributed/commands/utility_hook.h"
#include "distributed/connection_management.h"
#include "distributed/coordinator_protocol.h"
#include "distributed/enterprise.h"
#include "distributed/hash_helpers.h"
#include "distributed/listutils.h"
#include "distributed/lock_graph.h"
#include "distributed/metadata_cache.h"
#include "distributed/metadata_utility.h"
#include "distributed/multi_progress.h"
#include "distributed/multi_server_executor.h"
#include "distributed/pg_dist_rebalance_strategy.h"
#include "distributed/reference_table_utils.h"
#include "distributed/remote_commands.h"
#include "distributed/resource_lock.h"
#include "distributed/shard_cleaner.h"
#include "distributed/shard_rebalancer.h"
#include "distributed/shard_transfer.h"
#include "distributed/session_ctx.h"
#include "distributed/tuplestore.h"
#include "distributed/utils/array_type.h"
#include "distributed/worker_protocol.h"

/* RebalanceOptions are the options used to control the rebalance algorithm */
typedef struct RebalanceOptions {
    /* OIDs of the relations whose shard placements are considered */
    List* relationIdList;

    /* planning threshold; clamped to rebalanceStrategy->minimumThreshold */
    float4 threshold;

    /* upper bound on planned moves, forwarded to RebalancePlacementUpdates */
    int32 maxShardMoves;

    /* shardIds that must never be moved (see FullShardPlacementList) */
    ArrayType* excludedShardArray;

    /* forwarded to RebalancePlacementUpdates; presumably plan only node drains */
    bool drainOnly;

    float4 improvementThreshold;

    /* the pg_dist_rebalance_strategy row providing the cost/capacity UDFs */
    Form_pg_dist_rebalance_strategy rebalanceStrategy;

    /* human-readable operation name, used in error messages */
    const char* operationName;

    /* when non-NULL, restrict planning to placements on this node */
    WorkerNode* workerNode;
} RebalanceOptions;

/*
 * RebalanceState is used to keep the internal state of the rebalance
 * algorithm in one place.
 */
typedef struct RebalanceState {
    /*
     * placementsHash contains the current state of all shard placements, it
     * is initialized from pg_dist_placement and is then modified based on the
     * found shard moves.
     */
    HTAB* placementsHash;

    /*
     * placementUpdateList contains all of the updates that have been done to
     * reach the current state of placementsHash.
     */
    List* placementUpdateList;

    /*
     * functions holds the callbacks (shard cost, node capacity, shard
     * allowed on node) that drive planning; see GetRebalanceSteps.
     */
    RebalancePlanFunctions* functions;

    /*
     * fillStateListDesc contains all NodeFillStates ordered from full nodes to
     * empty nodes.
     */
    List* fillStateListDesc;

    /*
     * fillStateListAsc contains all NodeFillStates ordered from empty nodes to
     * full nodes.
     */
    List* fillStateListAsc;

    /*
     * disallowedPlacementList contains all placements that currently exist,
     * but are not allowed according to the shardAllowedOnNode function.
     */
    List* disallowedPlacementList;

    /*
     * totalCost is the cost of all the shards in the cluster added together.
     */
    float4 totalCost;

    /*
     * totalCapacity is the capacity of all the nodes in the cluster added
     * together.
     */
    float4 totalCapacity;

    /*
     * ignoredMoves is the number of moves that were ignored. This is used to
     * limit the amount of loglines we send.
     */
    int64 ignoredMoves;
} RebalanceState;

/*
 * RebalanceContext stores the context for the function callbacks: the fmgr
 * state for the rebalance strategy's three UDFs, initialized by fmgr_info()
 * in GetRebalanceSteps.
 */
typedef struct RebalanceContext {
    FmgrInfo shardCostUDF;           /* shard cost UDF of the strategy */
    FmgrInfo nodeCapacityUDF;        /* node capacity UDF of the strategy */
    FmgrInfo shardAllowedOnNodeUDF;  /* shard-allowed-on-node UDF of the strategy */
} RebalanceContext;

/* WorkerHashKey contains hostname and port to be used as a key in a hash */
typedef struct WorkerHashKey {
    /* fixed-size buffer so the struct can serve as a flat hash key */
    char hostname[MAX_NODE_LENGTH];
    int port;
} WorkerHashKey;

/* WorkerShardIds represents a set of shardIds grouped by worker */
typedef struct WorkerShardIds {
    /* hash key */
    WorkerHashKey worker;

    /* This is a uint64 hashset representing the shard ids for a specific worker */
    HTAB* shardIds;
} WorkerShardIds;

/* ShardStatistics contains statistics about a shard */
typedef struct ShardStatistics {
    uint64 shardId;

    /* the shard's size in bytes */
    uint64 totalSize;

    /* WAL position associated with this shard; consumed via WorkerShardLSN() */
    XLogRecPtr shardLSN;
} ShardStatistics;

/*
 * WorkerShardStatistics represents a set of statistics about shards,
 * grouped by worker.
 */
typedef struct WorkerShardStatistics {
    /* hash key */
    WorkerHashKey worker;

    /* WAL position reported by this worker; consumed via WorkerLSN() */
    XLogRecPtr workerLSN;

    /*
     * Statistics for each shard on this worker:
     * key: shardId
     * value: ShardStatistics
     */
    HTAB* statistics;
} WorkerShardStatistics;

/*
 * ShardMoveDependencyInfo contains the taskId which any new shard
 * move task within the corresponding colocation group
 * must take a dependency on.
 */
typedef struct ShardMoveDependencyInfo {
    /* hash key: a colocation group identifier */
    int64 key;

    /* task that later moves in the same colocation group must depend on */
    int64 taskId;
} ShardMoveDependencyInfo;

/*
 * ShardMoveSourceNodeHashEntry keeps track of the source nodes
 * of the moves.
 */
typedef struct ShardMoveSourceNodeHashEntry {
    /* this is the key */
    int32 node_id;

    /* ids of the move tasks that use this node as their source */
    List* taskIds;
} ShardMoveSourceNodeHashEntry;

/*
 * ShardMoveDependencies keeps track of all needed dependencies
 * between shard moves.
 */
typedef struct ShardMoveDependencies {
    /* keyed per colocation group; entries are ShardMoveDependencyInfo */
    HTAB* colocationDependencies;

    /* keyed per source node; entries are ShardMoveSourceNodeHashEntry */
    HTAB* nodeDependencies;
} ShardMoveDependencies;

/* static declarations for main logic */
static int ShardActivePlacementCount(HTAB* activePlacementsHash, uint64 shardId,
                                     List* activeWorkerNodeList);
static void UpdateShardPlacement(PlacementUpdateEvent* placementUpdateEvent,
                                 List* responsiveNodeList, Oid shardReplicationModeOid);

/* static declarations for main logic's utility functions */
static HTAB* ShardPlacementsListToHash(List* shardPlacementList);
static bool PlacementsHashFind(HTAB* placementsHash, uint64 shardId,
                               WorkerNode* workerNode);
static void PlacementsHashEnter(HTAB* placementsHash, uint64 shardId,
                                WorkerNode* workerNode);
static void PlacementsHashRemove(HTAB* placementsHash, uint64 shardId,
                                 WorkerNode* workerNode);
static int PlacementsHashCompare(const void* lhsKey, const void* rhsKey, Size keySize);
static uint32 PlacementsHashHashCode(const void* key, Size keySize);
static bool WorkerNodeListContains(List* workerNodeList, const char* workerName,
                                   uint32 workerPort);
static void UpdateColocatedShardPlacementProgress(uint64 shardId, char* sourceName,
                                                  int sourcePort, uint64 progress);
static NodeFillState* FindFillStateForPlacement(RebalanceState* state,
                                                ShardPlacement* placement);
static RebalanceState* InitRebalanceState(List* workerNodeList, List* shardPlacementList,
                                          RebalancePlanFunctions* functions);
static void MoveShardsAwayFromDisallowedNodes(RebalanceState* state);
static bool FindAndMoveShardCost(float4 utilizationLowerBound,
                                 float4 utilizationUpperBound,
                                 float4 improvementThreshold, RebalanceState* state);
static NodeFillState* FindAllowedTargetFillState(RebalanceState* state, uint64 shardId);
static void MoveShardCost(NodeFillState* sourceFillState, NodeFillState* targetFillState,
                          ShardCost* shardCost, RebalanceState* state);
static int CompareNodeFillStateAsc(const void* void1, const void* void2);
static int CompareNodeFillStateDesc(const void* void1, const void* void2);
static int CompareShardCostAsc(const void* void1, const void* void2);
static int CompareShardCostDesc(const void* void1, const void* void2);
static int CompareDisallowedPlacementAsc(const void* void1, const void* void2);
static int CompareDisallowedPlacementDesc(const void* void1, const void* void2);
static bool ShardAllowedOnNode(uint64 shardId, WorkerNode* workerNode, void* context);
static float4 NodeCapacity(WorkerNode* workerNode, void* context);
static ShardCost GetShardCost(uint64 shardId, void* context);
static List* NonColocatedDistRelationIdList(void);
static void RebalanceTableShards(RebalanceOptions* options, Oid shardReplicationModeOid);
static int64 RebalanceTableShardsBackground(RebalanceOptions* options,
                                            Oid shardReplicationModeOid);
static void AcquireRebalanceColocationLock(Oid relationId, const char* operationName);
static void ExecutePlacementUpdates(List* placementUpdateList,
                                    Oid shardReplicationModeOid, char* noticeOperation);
static float4 CalculateUtilization(float4 totalCost, float4 capacity);
static Form_pg_dist_rebalance_strategy GetRebalanceStrategy(Name name);
static void EnsureShardCostUDF(Oid functionOid);
static void EnsureNodeCapacityUDF(Oid functionOid);
static void EnsureShardAllowedOnNodeUDF(Oid functionOid);
static HTAB* BuildWorkerShardStatisticsHash(PlacementUpdateEventProgress* steps,
                                            int stepCount);
static HTAB* GetShardStatistics(MultiConnection* connection, HTAB* shardIds);
static HTAB* GetMovedShardIdsByWorker(PlacementUpdateEventProgress* steps, int stepCount,
                                      bool fromSource);
static uint64 WorkerShardSize(HTAB* workerShardStatistics, char* workerName,
                              int workerPort, uint64 shardId);
static XLogRecPtr WorkerShardLSN(HTAB* workerShardStatisticsHash, char* workerName,
                                 int workerPort, uint64 shardId);
static XLogRecPtr WorkerLSN(HTAB* workerShardStatisticsHash, char* workerName,
                            int workerPort);
static void AddToWorkerShardIdSet(HTAB* shardsByWorker, char* workerName, int workerPort,
                                  uint64 shardId);
static HTAB* BuildShardSizesHash(ProgressMonitorData* monitor, HTAB* shardStatistics);
static void ErrorOnConcurrentRebalance(RebalanceOptions*);
static List* GetSetCommandListForNewConnections(void);
static int64 GetColocationId(PlacementUpdateEvent* move);
static ShardMoveDependencies InitializeShardMoveDependencies();
static int64* GenerateTaskMoveDependencyList(PlacementUpdateEvent* move,
                                             int64 colocationId,
                                             ShardMoveDependencies shardMoveDependencies,
                                             int* nDepends);
static void UpdateShardMoveDependencies(PlacementUpdateEvent* move, uint64 colocationId,
                                        int64 taskId,
                                        ShardMoveDependencies shardMoveDependencies);
static XLogRecPtr GetRemoteLSN(MultiConnection* connection);
/* declarations for dynamic loading */
PG_FUNCTION_INFO_V1(rebalance_table_shards);
PG_FUNCTION_INFO_V1(replicate_table_shards);
PG_FUNCTION_INFO_V1(get_rebalance_table_shards_plan);
PG_FUNCTION_INFO_V1(get_rebalance_progress);
PG_FUNCTION_INFO_V1(citus_drain_node);
PG_FUNCTION_INFO_V1(citus_shard_cost_by_disk_size);
PG_FUNCTION_INFO_V1(citus_validate_rebalance_strategy_functions);
PG_FUNCTION_INFO_V1(pg_dist_rebalance_strategy_enterprise_check);
PG_FUNCTION_INFO_V1(spq_rebalance_start);
PG_FUNCTION_INFO_V1(citus_rebalance_stop);
PG_FUNCTION_INFO_V1(citus_rebalance_wait);

/*
 * The SQL-callable entry points are declared with C linkage because this
 * translation unit is compiled as C++.
 */
extern "C" Datum rebalance_table_shards(PG_FUNCTION_ARGS);
extern "C" Datum replicate_table_shards(PG_FUNCTION_ARGS);
extern "C" Datum get_rebalance_table_shards_plan(PG_FUNCTION_ARGS);
extern "C" Datum citus_drain_node(PG_FUNCTION_ARGS);
extern "C" Datum citus_shard_cost_by_disk_size(PG_FUNCTION_ARGS);
extern "C" Datum citus_validate_rebalance_strategy_functions(PG_FUNCTION_ARGS);
extern "C" Datum pg_dist_rebalance_strategy_enterprise_check(PG_FUNCTION_ARGS);
extern "C" Datum spq_rebalance_start(PG_FUNCTION_ARGS);
extern "C" Datum citus_rebalance_stop(PG_FUNCTION_ARGS);
extern "C" Datum citus_rebalance_wait(PG_FUNCTION_ARGS);

extern "C" Datum get_rebalance_progress(PG_FUNCTION_ARGS);

/* human-readable labels for PlacementUpdateType values, indexed by enum value */
static const char* PlacementUpdateTypeNames[] = {
    [PLACEMENT_UPDATE_INVALID_FIRST] = "unknown",
    [PLACEMENT_UPDATE_MOVE] = "move",
    [PLACEMENT_UPDATE_COPY] = "copy",
};

/* human-readable labels for PlacementUpdateStatus values, indexed by enum value */
static const char* PlacementUpdateStatusNames[] = {
    [PLACEMENT_UPDATE_STATUS_NOT_STARTED_YET] = "Not Started Yet",
    [PLACEMENT_UPDATE_STATUS_SETTING_UP] = "Setting Up",
    [PLACEMENT_UPDATE_STATUS_COPYING_DATA] = "Copying Data",
    [PLACEMENT_UPDATE_STATUS_CATCHING_UP] = "Catching Up",
    [PLACEMENT_UPDATE_STATUS_CREATING_CONSTRAINTS] = "Creating Constraints",
    [PLACEMENT_UPDATE_STATUS_FINAL_CATCH_UP] = "Final Catchup",
    [PLACEMENT_UPDATE_STATUS_CREATING_FOREIGN_KEYS] = "Creating Foreign Keys",
    [PLACEMENT_UPDATE_STATUS_COMPLETING] = "Completing",
    [PLACEMENT_UPDATE_STATUS_COMPLETED] = "Completed",
};

#ifdef USE_ASSERT_CHECKING

/*
 * Check that all the invariants of the state hold.
 */
static void CheckRebalanceStateInvariants(const RebalanceState* state)
{
    /*
     * Assert non-NULL before the first dereference; previously this assertion
     * came after state->fillStateListAsc was already read, which made it
     * useless as a NULL check.
     */
    Assert(state != NULL);

    NodeFillState* fillState = NULL;
    NodeFillState* prevFillState = NULL;
    int fillStateIndex = 0;
    int fillStateLength = list_length(state->fillStateListAsc);

    Assert(list_length(state->fillStateListAsc) == list_length(state->fillStateListDesc));
    foreach_declared_ptr(fillState, state->fillStateListAsc)
    {
        float4 totalCost = 0;
        ShardCost* shardCost = NULL;
        ShardCost* prevShardCost = NULL;
        if (prevFillState != NULL) {
            /* Check that the previous fill state is more empty than this one */
            bool higherUtilization = fillState->utilization > prevFillState->utilization;
            bool sameUtilization = fillState->utilization == prevFillState->utilization;
            bool lowerOrSameCapacity = fillState->capacity <= prevFillState->capacity;
            Assert(higherUtilization || (sameUtilization && lowerOrSameCapacity));
        }

        /* Check that fillStateListDesc is the reversed version of fillStateListAsc */
        Assert(list_nth(state->fillStateListDesc, fillStateLength - fillStateIndex - 1) ==
               fillState);

        foreach_declared_ptr(shardCost, fillState->shardCostListDesc)
        {
            if (prevShardCost != NULL) {
                /* Check that shard costs are sorted in descending order */
                Assert(shardCost->cost <= prevShardCost->cost);
            }
            totalCost += shardCost->cost;
            prevShardCost = shardCost;
        }

        /* Check that utilization field is up to date. */
        Assert(
            fillState->utilization ==
            CalculateUtilization(fillState->totalCost,
                                 fillState->capacity)); /* lgtm[cpp/equality-on-floats] */

        /*
         * Check that fillState->totalCost is within 0.1% difference of
         * sum(fillState->shardCostListDesc->cost)
         * We cannot compare exactly, because these numbers are floats and
         * fillState->totalCost is modified by doing + and - on it. So instead
         * we check that the numbers are roughly the same.
         */
        float4 absoluteDifferenceBetweenTotalCosts =
            fabsf(fillState->totalCost - totalCost);
        float4 maximumAbsoluteValueOfTotalCosts =
            fmaxf(fabsf(fillState->totalCost), fabsf(totalCost));
        Assert(absoluteDifferenceBetweenTotalCosts <=
               maximumAbsoluteValueOfTotalCosts / 1000);

        prevFillState = fillState;
        fillStateIndex++;
    }
}

#else
#define CheckRebalanceStateInvariants(l) ((void)0)
#endif /* USE_ASSERT_CHECKING */

/*
 * BigIntArrayDatumContains checks if the array contains the given number.
 */
static bool BigIntArrayDatumContains(Datum* array, int arrayLength, uint64 toFind)
{
    Datum* arrayEnd = array + arrayLength;
    for (Datum* cursor = array; cursor < arrayEnd; cursor++) {
        if (DatumGetInt64(*cursor) == toFind) {
            return true;
        }
    }
    return false;
}

/*
 * FullShardPlacementList returns a List containing all the shard placements of
 * a specific table (excluding the excludedShardArray), sorted with
 * CompareShardPlacements for a deterministic ordering.
 */
static List* FullShardPlacementList(Oid relationId, ArrayType* excludedShardArray)
{
    List* shardPlacementList = NIL;
    CitusTableCacheEntry* citusTableCacheEntry = GetCitusTableCacheEntry(relationId);
    int shardIntervalArrayLength = citusTableCacheEntry->shardIntervalArrayLength;
    int excludedShardIdCount = ArrayObjectCount(excludedShardArray);
    Datum* excludedShardArrayDatum = DeconstructArrayObject(excludedShardArray);

    for (int shardIndex = 0; shardIndex < shardIntervalArrayLength; shardIndex++) {
        ShardInterval* shardInterval =
            citusTableCacheEntry->sortedShardIntervalArray[shardIndex];
        GroupShardPlacement* placementArray =
            citusTableCacheEntry->arrayOfPlacementArrays[shardIndex];
        int numberOfPlacements =
            citusTableCacheEntry->arrayOfPlacementArrayLengths[shardIndex];

        /* skip shards the caller explicitly excluded */
        if (BigIntArrayDatumContains(excludedShardArrayDatum, excludedShardIdCount,
                                     shardInterval->shardId)) {
            continue;
        }

        /* build a ShardPlacement node for every placement of this shard */
        for (int placementIndex = 0; placementIndex < numberOfPlacements;
             placementIndex++) {
            GroupShardPlacement* groupPlacement = &placementArray[placementIndex];

            /* resolve the placement's group to a concrete worker node */
            WorkerNode* worker = LookupNodeForGroup(groupPlacement->groupId);
            ShardPlacement* placement = CitusMakeNode(ShardPlacement);
            placement->shardId = groupPlacement->shardId;
            placement->shardLength = groupPlacement->shardLength;
            placement->nodeId = worker->nodeId;
            placement->nodeName = pstrdup(worker->workerName);
            placement->nodePort = worker->workerPort;
            placement->placementId = groupPlacement->placementId;

            shardPlacementList = lappend(shardPlacementList, placement);
        }
    }
    return SortList(shardPlacementList, CompareShardPlacements);
}

/*
 * SortedActiveWorkers returns all the active workers like
 * ActiveReadableNodeList, but sorted with CompareWorkerNodes.
 */
static List* SortedActiveWorkers()
{
    return SortList(ActiveReadableNodeList(), CompareWorkerNodes);
}

/*
 * GetRebalanceSteps returns a List of PlacementUpdateEvents that are needed to
 * rebalance a list of tables.
 */
static List* GetRebalanceSteps(RebalanceOptions* options)
{
    /* validate the strategy's UDF signatures before wiring them into fmgr */
    EnsureShardCostUDF(options->rebalanceStrategy->shardCostFunction);
    EnsureNodeCapacityUDF(options->rebalanceStrategy->nodeCapacityFunction);
    EnsureShardAllowedOnNodeUDF(options->rebalanceStrategy->shardAllowedOnNodeFunction);

    RebalanceContext context;
    memset(&context, 0, sizeof(RebalanceContext));
    fmgr_info(options->rebalanceStrategy->shardCostFunction, &context.shardCostUDF);
    fmgr_info(options->rebalanceStrategy->nodeCapacityFunction, &context.nodeCapacityUDF);
    fmgr_info(options->rebalanceStrategy->shardAllowedOnNodeFunction,
              &context.shardAllowedOnNodeUDF);

    /* C callbacks that delegate to the strategy's UDFs through the context */
    RebalancePlanFunctions rebalancePlanFunctions = {
        .shardAllowedOnNode = ShardAllowedOnNode,
        .nodeCapacity = NodeCapacity,
        .shardCost = GetShardCost,
        .context = &context,
    };

    /* sort the lists to make the function more deterministic */
    List* activeWorkerList = SortedActiveWorkers();
    int shardAllowedNodeCount = 0;
    WorkerNode* workerNode = NULL;
    foreach_declared_ptr(workerNode, activeWorkerList)
    {
        if (workerNode->shouldHaveShards) {
            shardAllowedNodeCount++;
        }
    }

    /* cannot place more replicas of a shard than there are nodes accepting shards */
    if (shardAllowedNodeCount < Session_ctx::Vars().ShardReplicationFactor) {
        ereport(ERROR, (errmsg("Shard replication factor (%d) cannot be greater than "
                               "number of nodes with should_have_shards=true (%d).",
                               Session_ctx::Vars().ShardReplicationFactor,
                               shardAllowedNodeCount)));
    }

    List* activeShardPlacementListList = NIL;
    List* unbalancedShards = NIL;

    Oid relationId = InvalidOid;
    foreach_declared_oid(relationId, options->relationIdList)
    {
        List* shardPlacementList =
            FullShardPlacementList(relationId, options->excludedShardArray);
        List* activeShardPlacementListForRelation =
            FilterShardPlacementList(shardPlacementList, IsActiveShardPlacement);

        /* when a specific node was given, only consider placements on that node */
        if (options->workerNode != NULL) {
            activeShardPlacementListForRelation = FilterActiveShardPlacementListByNode(
                shardPlacementList, options->workerNode);
        }

        if (list_length(activeShardPlacementListForRelation) >= shardAllowedNodeCount) {
            activeShardPlacementListList = lappend(activeShardPlacementListList,
                                                   activeShardPlacementListForRelation);
        } else {
            /*
             * If the number of shard groups are less than the number of worker nodes,
             * at least one of the worker nodes will remain empty. For such cases,
             * we consider those shard groups as a colocation group and try to
             * distribute them across the cluster.
             */
            unbalancedShards =
                list_concat(unbalancedShards, activeShardPlacementListForRelation);
        }
    }

    if (list_length(unbalancedShards) > 0) {
        activeShardPlacementListList =
            lappend(activeShardPlacementListList, unbalancedShards);
    }

    /* never plan with a threshold below what the strategy allows */
    if (options->threshold < options->rebalanceStrategy->minimumThreshold) {
        ereport(WARNING, (errmsg("the given threshold is lower than the minimum "
                                 "threshold allowed by the rebalance strategy, "
                                 "using the minimum allowed threshold instead"),
                          errdetail("Using threshold of %.2f",
                                    options->rebalanceStrategy->minimumThreshold)));
        options->threshold = options->rebalanceStrategy->minimumThreshold;
    }

    return RebalancePlacementUpdates(activeWorkerList, activeShardPlacementListList,
                                     options->threshold, options->maxShardMoves,
                                     options->drainOnly, options->improvementThreshold,
                                     &rebalancePlanFunctions);
}

/*
 * ShardAllowedOnNode determines if shard is allowed on a specific worker node.
 * Nodes with shouldHaveShards=false never accept shards; otherwise the
 * decision is delegated to the rebalance strategy's shardAllowedOnNode UDF.
 */
static bool ShardAllowedOnNode(uint64 shardId, WorkerNode* workerNode, void* voidContext)
{
    if (!workerNode->shouldHaveShards) {
        return false;
    }

    RebalanceContext* context = static_cast<RebalanceContext*>(voidContext);

    /*
     * Use the explicit Datum conversion macros (as GetShardCost already does)
     * instead of relying on implicit integer-to-Datum conversion; this is the
     * portable way to pass a 64-bit value as a Datum.
     */
    Datum allowed = FunctionCall2(&context->shardAllowedOnNodeUDF,
                                  UInt64GetDatum(shardId),
                                  UInt32GetDatum(workerNode->nodeId));
    return DatumGetBool(allowed);
}

/*
 * NodeCapacity returns the relative capacity of a node. A node with capacity 2
 * can contain twice as many shards as a node with capacity 1. The actual
 * capacity can be a number grounded in reality, like the disk size, number of
 * cores, but it doesn't have to be.
 *
 * Nodes with shouldHaveShards=false always report capacity 0; otherwise the
 * value comes from the rebalance strategy's nodeCapacity UDF.
 */
static float4 NodeCapacity(WorkerNode* workerNode, void* voidContext)
{
    if (!workerNode->shouldHaveShards) {
        return 0;
    }

    RebalanceContext* context = static_cast<RebalanceContext*>(voidContext);

    /* pass nodeId through the explicit Datum conversion macro for consistency */
    Datum capacity = FunctionCall1(&context->nodeCapacityUDF,
                                   UInt32GetDatum(workerNode->nodeId));
    return DatumGetFloat4(capacity);
}

/*
 * GetShardCost returns the cost of the given shard. A shard with cost 2 will
 * be weighted as heavily as two shards with cost 1. This cost number can be a
 * number grounded in reality, like the shard size on disk, but it doesn't have
 * to be.
 */
static ShardCost GetShardCost(uint64 shardId, void* voidContext)
{
    RebalanceContext* context = static_cast<RebalanceContext*>(voidContext);

    /* delegate the cost computation to the strategy's shard cost UDF */
    Datum costDatum = FunctionCall1(&context->shardCostUDF, UInt64GetDatum(shardId));

    ShardCost shardCost = {0};
    shardCost.shardId = shardId;
    shardCost.cost = DatumGetFloat4(costDatum);
    return shardCost;
}

/*
 * citus_shard_cost_by_disk_size gets the cost for a shard based on the disk
 * size of the shard on a worker. The worker to check the disk size is
 * determined by choosing the first active placement for the shard. The disk
 * size is calculated using pg_total_relation_size, so it includes indexes.
 *
 * SQL signature:
 * citus_shard_cost_by_disk_size(shardid bigint) returns float4
 */
Datum citus_shard_cost_by_disk_size(PG_FUNCTION_ARGS)
{
    CheckCitusVersion(ERROR);
    uint64 shardId = PG_GETARG_INT64(0);
    bool missingOk = false;
    ShardPlacement* shardPlacement = ActiveShardPlacement(shardId, missingOk);

    /* short-lived context so metadata lookups don't accumulate in the caller's */
    MemoryContext localContext = AllocSetContextCreate(
        CurrentMemoryContext, "CostByDiscSizeContext", ALLOCSET_DEFAULT_SIZES);
    MemoryContext oldContext = MemoryContextSwitchTo(localContext);
    ShardInterval* shardInterval = LoadShardInterval(shardId);
    List* colocatedShardList = ColocatedNonPartitionShardIntervalList(shardInterval);

    uint64 colocationSizeInBytes = ShardListSizeInBytes(
        colocatedShardList, shardPlacement->nodeName, shardPlacement->nodePort);

    MemoryContextSwitchTo(oldContext);
    MemoryContextReset(localContext);

    colocationSizeInBytes += Session_ctx::Vars().RebalancerByDiskSizeBaseCost;

    /*
     * colocationSizeInBytes is unsigned, so the old "<= 0" check could only
     * ever mean "== 0"; spell it that way. Return a minimum cost of 1 so we
     * never report a zero cost.
     */
    if (colocationSizeInBytes == 0) {
        PG_RETURN_FLOAT4(1);
    }

    PG_RETURN_FLOAT4((float4)colocationSizeInBytes);
}

/*
 * GetColocatedRebalanceSteps takes a List of PlacementUpdateEvents and creates
 * a new List containing those and all the updates for colocated shards.
 */
static List* GetColocatedRebalanceSteps(List* placementUpdateList)
{
    List* expandedUpdateList = NIL;

    PlacementUpdateEvent* update = NULL;
    foreach_declared_ptr(update, placementUpdateList)
    {
        ShardInterval* shardInterval = LoadShardInterval(update->shardId);
        List* colocatedShardList = ColocatedShardIntervalList(shardInterval);

        /* emit one update per colocated shard (including the original shard) */
        ShardInterval* colocatedShard = NULL;
        foreach_declared_ptr(colocatedShard, colocatedShardList)
        {
            PlacementUpdateEvent* expandedUpdate =
                static_cast<PlacementUpdateEvent*>(palloc0(sizeof(PlacementUpdateEvent)));

            expandedUpdate->shardId = colocatedShard->shardId;
            expandedUpdate->sourceNode = update->sourceNode;
            expandedUpdate->targetNode = update->targetNode;
            expandedUpdate->updateType = update->updateType;

            expandedUpdateList = lappend(expandedUpdateList, expandedUpdate);
        }
    }

    return expandedUpdateList;
}

/*
 * AcquireRebalanceColocationLock tries to acquire a lock for
 * rebalance/replication. If this is not possible, it fails instantly because
 * this means another rebalance/replication is currently happening. This would
 * really mess up planning.
 */
static void AcquireRebalanceColocationLock(Oid relationId, const char* operationName)
{
    /* colocated tables share one lock: prefer the colocation id as lock id */
    CitusTableCacheEntry* cacheEntry = GetCitusTableCacheEntry(relationId);
    uint32 lockId = (cacheEntry->colocationId != INVALID_COLOCATION_ID)
                        ? cacheEntry->colocationId
                        : relationId;

    LOCKTAG tag;
    SET_LOCKTAG_REBALANCE_COLOCATION(tag, (int64)lockId);

    bool sessionLock = false;
    bool dontWait = true;
    if (!LockAcquire(&tag, ExclusiveLock, sessionLock, dontWait)) {
        ereport(ERROR,
                (errmsg("could not acquire the lock required to %s %s", operationName,
                        generate_qualified_relation_name(relationId)),
                 errdetail("It means that either a concurrent shard move "
                           "or shard copy is happening."),
                 errhint("Make sure that the concurrent operation has "
                         "finished and re-run the command")));
    }
}

/*
 * AcquirePlacementColocationLock tries to acquire a lock for
 * rebalance/replication while moving/copying the placement. If this is not
 * possible, it fails instantly because this means another move/copy is
 * currently happening. This would really mess up planning.
 */
void AcquirePlacementColocationLock(Oid relationId, int lockMode,
                                    const char* operationName)
{
    /* colocated tables share one lock: prefer the colocation id as lock id */
    CitusTableCacheEntry* cacheEntry = GetCitusTableCacheEntry(relationId);
    uint32 lockId = (cacheEntry->colocationId != INVALID_COLOCATION_ID)
                        ? cacheEntry->colocationId
                        : relationId;

    LOCKTAG tag;
    SET_LOCKTAG_REBALANCE_PLACEMENT_COLOCATION(tag, (int64)lockId);

    bool sessionLock = false;
    bool dontWait = true;
    if (!LockAcquire(&tag, lockMode, sessionLock, dontWait)) {
        ereport(ERROR,
                (errmsg("could not acquire the lock required to %s %s", operationName,
                        generate_qualified_relation_name(relationId)),
                 errdetail("It means that either a concurrent shard move "
                           "or colocated distributed table creation is "
                           "happening."),
                 errhint("Make sure that the concurrent operation has "
                         "finished and re-run the command")));
    }
}

/*
 * GetResponsiveWorkerList returns a List of workers that respond to new
 * connection requests.
 */
static List* GetResponsiveWorkerList()
{
    List* responsiveWorkerList = NIL;
    List* activeWorkerList = ActiveReadableNodeList();

    WorkerNode* worker = NULL;
    foreach_declared_ptr(worker, activeWorkerList)
    {
        /* probe with a brand new connection so cached state can't fool us */
        MultiConnection* connection = GetNodeConnection(FORCE_NEW_CONNECTION,
                                                        worker->workerName,
                                                        worker->workerPort);
        if (connection == NULL || connection->pgConn == NULL) {
            continue;
        }

        if (PQstatus(connection->pgConn) == CONNECTION_OK) {
            responsiveWorkerList = lappend(responsiveWorkerList, worker);
        }

        CloseConnection(connection);
    }

    return responsiveWorkerList;
}

/*
 * ExecutePlacementUpdates copies or moves a shard placement by calling the
 * corresponding functions in Citus in a separate subtransaction for each
 * update.
 */
static void ExecutePlacementUpdates(List* placementUpdateList,
                                    Oid shardReplicationModeOid, char* noticeOperation)
{
    List* responsiveWorkerList = GetResponsiveWorkerList();

    /*
     * Use a dedicated memory context that is reset after every placement
     * update so per-update allocations do not accumulate across the loop.
     */
    MemoryContext localContext = AllocSetContextCreate(
        CurrentMemoryContext, "ExecutePlacementLoopContext", ALLOCSET_DEFAULT_SIZES);
    MemoryContext oldContext = MemoryContextSwitchTo(localContext);

    ListCell* placementUpdateCell = NULL;

    /* clean leftovers from earlier (possibly failed) operations first */
    DropOrphanedResourcesInSeparateTransaction();

    foreach (placementUpdateCell, placementUpdateList) {
        PlacementUpdateEvent* placementUpdate =
            static_cast<PlacementUpdateEvent*>(lfirst(placementUpdateCell));

        /*
         * shardId is a uint64; use UINT64_FORMAT instead of %lu so the
         * format specifier also matches on LLP64 platforms (e.g. 64-bit
         * Windows) where long is only 32 bits.
         */
        ereport(NOTICE,
                (errmsg("%s shard " UINT64_FORMAT " from %s:%u to %s:%u ...",
                        noticeOperation, placementUpdate->shardId,
                        placementUpdate->sourceNode->workerName,
                        placementUpdate->sourceNode->workerPort,
                        placementUpdate->targetNode->workerName,
                        placementUpdate->targetNode->workerPort)));
        UpdateShardPlacement(placementUpdate, responsiveWorkerList,
                             shardReplicationModeOid);
        MemoryContextReset(localContext);
    }
    MemoryContextSwitchTo(oldContext);
}

/*
 * SetupRebalanceMonitor initializes the dynamic shared memory required for storing the
 * progress information of a rebalance process. The function takes a List of
 * PlacementUpdateEvents for all shards that will be moved (including colocated
 * ones) and the relation id of the target table. The dynamic shared memory
 * portion consists of a RebalanceMonitorHeader and multiple
 * PlacementUpdateEventProgress, one for each planned shard placement move. The
 * dsm_handle of the created segment is saved in the progress of the current backend so
 * that it can be read by external agents such as get_rebalance_progress function by
 * calling pg_stat_get_progress_info UDF. Since currently only VACUUM commands are
 * officially allowed as the command type, we describe ourselves as a VACUUM command and
 * in order to distinguish a rebalancer progress from regular VACUUM progresses, we put
 * a magic number to the first progress field as an indicator. Finally we return the
 * dsm handle so that it can be used for updating the progress and cleaning things up.
 */
void SetupRebalanceMonitor(List* placementUpdateList, Oid relationId,
                           uint64 initialProgressState,
                           PlacementUpdateStatus initialStatus)
{
    /*
     * NOTE(review): the entire monitor setup below is compiled out unless
     * DISABLE_OG_COMMENTS is defined; without it this function is a no-op
     * and no rebalance progress monitor is ever registered. Confirm this
     * build-time gating is intentional.
     */
#ifdef DISABLE_OG_COMMENTS
    /* expand the steps so every colocated shard placement gets its own slot */
    List* colocatedUpdateList = GetColocatedRebalanceSteps(placementUpdateList);
    ListCell* colocatedUpdateCell = NULL;

    /* one PlacementUpdateEventProgress per planned move, in shared memory */
    dsm_handle dsmHandle;
    ProgressMonitorData* monitor =
        CreateProgressMonitor(list_length(colocatedUpdateList),
                              sizeof(PlacementUpdateEventProgress), &dsmHandle);
    PlacementUpdateEventProgress* rebalanceSteps =
        static_cast<PlacementUpdateEventProgress*>(ProgressMonitorSteps(monitor));

    int32 eventIndex = 0;
    foreach (colocatedUpdateCell, colocatedUpdateList) {
        PlacementUpdateEvent* colocatedUpdate =
            static_cast<PlacementUpdateEvent*>(lfirst(colocatedUpdateCell));
        PlacementUpdateEventProgress* event = rebalanceSteps + eventIndex;

        /*
         * NOTE(review): 255 is assumed to match the sourceName/targetName
         * buffer sizes in PlacementUpdateEventProgress — confirm against the
         * struct definition (sizeof(event->sourceName) would be safer).
         */
        strlcpy(event->sourceName, colocatedUpdate->sourceNode->workerName, 255);
        strlcpy(event->targetName, colocatedUpdate->targetNode->workerName, 255);

        event->shardId = colocatedUpdate->shardId;
        event->sourcePort = colocatedUpdate->sourceNode->workerPort;
        event->targetPort = colocatedUpdate->targetNode->workerPort;
        event->updateType = colocatedUpdate->updateType;
        /* atomics: these fields are read concurrently by get_rebalance_progress */
        pg_atomic_init_u64(&event->updateStatus, initialStatus);
        pg_atomic_init_u64(&event->progress, initialProgressState);

        eventIndex++;
    }
    /* publish the segment so external backends can find it via the magic number */
    RegisterProgressMonitor(REBALANCE_ACTIVITY_MAGIC_NUMBER, relationId, dsmHandle);
#endif
}

/*
 * rebalance_table_shards rebalances the shards across the workers.
 *
 * SQL signature:
 *
 * rebalance_table_shards(
 *     relation regclass,
 *     threshold float4,
 *     max_shard_moves int,
 *     excluded_shard_list bigint[],
 *     shard_transfer_mode spq.shard_transfer_mode,
 *     drain_only boolean,
 *     rebalance_strategy name
 * ) RETURNS VOID
 */
Datum rebalance_table_shards(PG_FUNCTION_ARGS)
{
    CheckCitusVersion(ERROR);

    List* relationIdList = NIL;
    if (PG_ARGISNULL(0)) {
        /*
         * No relation given: rebalance all non-colocated distributed tables.
         * NonColocatedDistRelationIdList never returns non-distributed
         * tables, so no extra citus-local-table checks are needed here.
         */
        relationIdList = NonColocatedDistRelationIdList();
    } else {
        Oid relationId = PG_GETARG_OID(0);
        ErrorIfMoveUnsupportedTableType(relationId);
        relationIdList = list_make1_oid(relationId);
    }

    PG_ENSURE_ARGNOTNULL(2, "max_shard_moves");
    PG_ENSURE_ARGNOTNULL(3, "excluded_shard_list");
    PG_ENSURE_ARGNOTNULL(4, "shard_transfer_mode");
    PG_ENSURE_ARGNOTNULL(5, "drain_only");

    Form_pg_dist_rebalance_strategy strategy =
        GetRebalanceStrategy(PG_GETARG_NAME_OR_NULL(6));
    Oid shardTransferModeOid = PG_GETARG_OID(4);

    RebalanceOptions options = {
        .relationIdList = relationIdList,
        .threshold = PG_GETARG_FLOAT4_OR_DEFAULT(1, strategy->defaultThreshold),
        .maxShardMoves = PG_GETARG_INT32(2),
        .excludedShardArray = PG_GETARG_ARRAYTYPE_P(3),
        .drainOnly = PG_GETARG_BOOL(5),
        .improvementThreshold = strategy->improvementThreshold,
        .rebalanceStrategy = strategy,
    };

    RebalanceTableShards(&options, shardTransferModeOid);
    PG_RETURN_VOID();
}

/*
 * spq_rebalance_start rebalances the shards across the workers.
 *
 * SQL signature:
 *
 * spq_rebalance_start(
 *     rebalance_strategy name DEFAULT NULL,
 *     drain_only boolean DEFAULT false,
 *     shard_transfer_mode spq.shard_transfer_mode default 'auto'
 * ) RETURNS bigint (the background job id, or NULL when no job was scheduled)
 */
Datum spq_rebalance_start(PG_FUNCTION_ARGS)
{
    CheckCitusVersion(ERROR);

    /* rebalance every non-colocated distributed table */
    List* relationIdList = NonColocatedDistRelationIdList();
    Form_pg_dist_rebalance_strategy strategy =
        GetRebalanceStrategy(PG_GETARG_NAME_OR_NULL(0));

    PG_ENSURE_ARGNOTNULL(1, "drain_only");
    bool drainOnly = PG_GETARG_BOOL(1);

    PG_ENSURE_ARGNOTNULL(2, "shard_transfer_mode");
    Oid shardTransferModeOid = PG_GETARG_OID(2);

    RebalanceOptions options = {.relationIdList = relationIdList,
                                .threshold = strategy->defaultThreshold,
                                .maxShardMoves = 10000000,
                                .excludedShardArray = construct_empty_array(INT4OID),
                                .drainOnly = drainOnly,
                                .improvementThreshold = strategy->improvementThreshold,
                                .rebalanceStrategy = strategy};

    /*
     * Store the job id as int64: it is returned through PG_RETURN_INT64
     * below, and keeping it in a plain int (as before) would silently
     * truncate job ids larger than INT_MAX.
     */
    int64 jobId = RebalanceTableShardsBackground(&options, shardTransferModeOid);

    /* a zero job id means no background job was scheduled */
    if (jobId == 0) {
        PG_RETURN_NULL();
    }
    PG_RETURN_INT64(jobId);
}

/*
 * citus_rebalance_stop stops any ongoing background rebalance that is executing.
 * Raises an error when there is no background rebalance ongoing at the moment.
 */
Datum citus_rebalance_stop(PG_FUNCTION_ARGS)
{
    CheckCitusVersion(ERROR);

    /* find the active rebalance job, if any */
    int64 jobId = 0;
    bool jobFound = HasNonTerminalJobOfType("rebalance", &jobId);
    if (!jobFound) {
        ereport(ERROR, (errmsg("no ongoing rebalance that can be stopped")));
    }

    /* cancel it through the regular background-job cancellation path */
    DirectFunctionCall1(citus_job_cancel, Int64GetDatum(jobId));

    PG_RETURN_VOID();
}

/*
 * citus_rebalance_wait waits till an ongoing background rebalance has finished execution.
 * A warning will be displayed if no rebalance is ongoing.
 */
Datum citus_rebalance_wait(PG_FUNCTION_ARGS)
{
    CheckCitusVersion(ERROR);

    int64 jobId = 0;
    bool rebalanceFound = HasNonTerminalJobOfType("rebalance", &jobId);
    if (!rebalanceFound) {
        /* nothing to wait for: warn instead of erroring out */
        ereport(WARNING, (errmsg("no ongoing rebalance that can be waited on")));
        PG_RETURN_VOID();
    }

    /* block until the background job reaches a terminal state */
    citus_job_wait_internal(jobId, NULL);

    PG_RETURN_VOID();
}

/*
 * GetRebalanceStrategy returns the rebalance strategy from
 * pg_dist_rebalance_strategy matching the given name. If name is NULL it
 * returns the default rebalance strategy from pg_dist_rebalance_strategy.
 */
static Form_pg_dist_rebalance_strategy GetRebalanceStrategy(Name name)
{
    Relation pgDistRebalanceStrategy =
        table_open(DistRebalanceStrategyRelationId(), AccessShareLock);

    ScanKeyData scanKey[1];
    const int scanKeyCount = 1;
    if (name != NULL) {
        /* WHERE name=$name */
        ScanKeyInit(&scanKey[0], Anum_pg_dist_rebalance_strategy_name,
                    BTEqualStrategyNumber, F_NAMEEQ, NameGetDatum(name));
    } else {
        /* WHERE default_strategy=true */
        ScanKeyInit(&scanKey[0], Anum_pg_dist_rebalance_strategy_default_strategy,
                    BTEqualStrategyNumber, F_BOOLEQ, BoolGetDatum(true));
    }

    SysScanDesc scanDescriptor = systable_beginscan(pgDistRebalanceStrategy, InvalidOid,
                                                    false, NULL, scanKeyCount, scanKey);

    HeapTuple heapTuple = systable_getnext(scanDescriptor);
    if (!HeapTupleIsValid(heapTuple)) {
        if (name == NULL) {
            ereport(ERROR, (errmsg("no rebalance_strategy was provided, but there is "
                                   "also no default strategy set")));
        }
        ereport(ERROR,
                (errmsg("could not find rebalance strategy with name %s", (char*)name)));
    }

    /*
     * Copy the row into palloc'd memory before closing the scan so the
     * returned pointer stays valid after the buffer pin is released.
     */
    Form_pg_dist_rebalance_strategy strategyCopy =
        static_cast<Form_pg_dist_rebalance_strategy>(
            palloc0(sizeof(FormData_pg_dist_rebalance_strategy)));
    *strategyCopy = *(Form_pg_dist_rebalance_strategy)GETSTRUCT(heapTuple);

    systable_endscan(scanDescriptor);
    table_close(pgDistRebalanceStrategy, NoLock);

    return strategyCopy;
}

/*
 * citus_drain_node drains a node by setting shouldhaveshards to false and
 * running the rebalancer after in drain_only mode.
 */
Datum citus_drain_node(PG_FUNCTION_ARGS)
{
    CheckCitusVersion(ERROR);
    PG_ENSURE_ARGNOTNULL(0, "nodename");
    PG_ENSURE_ARGNOTNULL(1, "nodeport");
    PG_ENSURE_ARGNOTNULL(2, "shard_transfer_mode");

    char* nodeName = text_to_cstring(PG_GETARG_TEXT_P(0));
    int32 nodePort = PG_GETARG_INT32(1);
    Oid shardTransferModeOid = PG_GETARG_OID(2);
    Form_pg_dist_rebalance_strategy strategy =
        GetRebalanceStrategy(PG_GETARG_NAME_OR_NULL(3));

    /* NOTE(review): maxShardMoves is 0 here — confirm drain moves are unlimited */
    RebalanceOptions options = {
        .relationIdList = NonColocatedDistRelationIdList(),
        .threshold = strategy->defaultThreshold,
        .maxShardMoves = 0,
        .excludedShardArray = construct_empty_array(INT4OID),
        .drainOnly = true,
        .rebalanceStrategy = strategy,
    };
    options.workerNode = FindWorkerNodeOrError(nodeName, nodePort);

    /*
     * Flip shouldhaveshards in a separate session so the change survives
     * even when the draining itself fails midway through.
     */
    ExecuteRebalancerCommandInSeparateTransaction(
        psprintf("SELECT spq_set_node_property(%s, %i, 'shouldhaveshards', false)",
                 quote_literal_cstr(nodeName), nodePort));

    RebalanceTableShards(&options, shardTransferModeOid);

    PG_RETURN_VOID();
}

/*
 * replicate_table_shards replicates under-replicated shards of the specified
 * table.
 */
Datum replicate_table_shards(PG_FUNCTION_ARGS)
{
    CheckCitusVersion(ERROR);

    Oid relationId = PG_GETARG_OID(0);
    uint32 shardReplicationFactor = PG_GETARG_INT32(1);
    int32 maxShardCopies = PG_GETARG_INT32(2);
    ArrayType* excludedShardArray = PG_GETARG_ARRAYTYPE_P(3);
    Oid shardReplicationModeOid = PG_GETARG_OID(4);

    if (IsCitusTableType(relationId, SINGLE_SHARD_DISTRIBUTED)) {
        ereport(ERROR, (errmsg("cannot replicate single shard tables' shards")));
    }

    char transferMode = LookupShardTransferMode(shardReplicationModeOid);
    EnsureReferenceTablesExistOnAllNodesExtended(transferMode);

    /* serialize with concurrent rebalances on the same colocation group */
    AcquireRebalanceColocationLock(relationId, "replicate");

    List* workerNodeList = SortedActiveWorkers();
    List* allPlacementList = FullShardPlacementList(relationId, excludedShardArray);
    List* activePlacementList =
        FilterShardPlacementList(allPlacementList, IsActiveShardPlacement);

    /* compute the copies needed to reach the requested replication factor */
    List* placementUpdateList = ReplicationPlacementUpdates(
        workerNodeList, activePlacementList, shardReplicationFactor);
    placementUpdateList = list_truncate(placementUpdateList, maxShardCopies);

    ExecutePlacementUpdates(placementUpdateList, shardReplicationModeOid, "Copying");

    PG_RETURN_VOID();
}

/*
 * get_rebalance_table_shards_plan function calculates the shard move steps
 * required for the rebalance operations including the ones for colocated
 * tables.
 *
 * SQL signature:
 *
 * get_rebalance_table_shards_plan(
 *     relation regclass,
 *     threshold float4,
 *     max_shard_moves int,
 *     excluded_shard_list bigint[],
 *     drain_only boolean,
 *     rebalance_strategy name,
 *     improvement_threshold float4
 * )
 */
Datum get_rebalance_table_shards_plan(PG_FUNCTION_ARGS)
{
    CheckCitusVersion(ERROR);

    List* relationIdList = NIL;
    if (PG_ARGISNULL(0)) {
        /*
         * No relation given: plan for all non-colocated distributed tables.
         * NonColocatedDistRelationIdList never returns non-distributed
         * tables, so no citus-local-table checks are needed here.
         */
        relationIdList = NonColocatedDistRelationIdList();
    } else {
        Oid relationId = PG_GETARG_OID(0);
        ErrorIfMoveUnsupportedTableType(relationId);
        relationIdList = list_make1_oid(relationId);
    }

    PG_ENSURE_ARGNOTNULL(2, "max_shard_moves");
    PG_ENSURE_ARGNOTNULL(3, "excluded_shard_list");
    PG_ENSURE_ARGNOTNULL(4, "drain_only");

    Form_pg_dist_rebalance_strategy strategy =
        GetRebalanceStrategy(PG_GETARG_NAME_OR_NULL(5));
    RebalanceOptions options = {
        .relationIdList = relationIdList,
        .threshold = PG_GETARG_FLOAT4_OR_DEFAULT(1, strategy->defaultThreshold),
        .maxShardMoves = PG_GETARG_INT32(2),
        .excludedShardArray = PG_GETARG_ARRAYTYPE_P(3),
        .drainOnly = PG_GETARG_BOOL(4),
        .improvementThreshold =
            PG_GETARG_FLOAT4_OR_DEFAULT(6, strategy->improvementThreshold),
        .rebalanceStrategy = strategy,
    };

    /* compute planned moves, then expand them over colocated tables */
    List* placementUpdateList = GetRebalanceSteps(&options);
    List* colocatedUpdateList = GetColocatedRebalanceSteps(placementUpdateList);

    TupleDesc tupdesc;
    Tuplestorestate* tupstore = SetupTuplestore(fcinfo, &tupdesc);

    ListCell* updateCell = NULL;
    foreach (updateCell, colocatedUpdateList) {
        PlacementUpdateEvent* update =
            static_cast<PlacementUpdateEvent*>(lfirst(updateCell));

        Datum values[7];
        bool nulls[7];
        memset(values, 0, sizeof(values));
        memset(nulls, 0, sizeof(nulls));

        /* (table, shardid, size, source host/port, target host/port) */
        values[0] = ObjectIdGetDatum(RelationIdForShard(update->shardId));
        values[1] = UInt64GetDatum(update->shardId);
        values[2] = UInt64GetDatum(ShardLength(update->shardId));
        values[3] = PointerGetDatum(cstring_to_text(update->sourceNode->workerName));
        values[4] = UInt32GetDatum(update->sourceNode->workerPort);
        values[5] = PointerGetDatum(cstring_to_text(update->targetNode->workerName));
        values[6] = UInt32GetDatum(update->targetNode->workerPort);

        tuplestore_putvalues(tupstore, tupdesc, values, nulls);
    }

    return (Datum)0;
}

/*
 * get_rebalance_progress collects information about the ongoing rebalance operations and
 * returns the concatenated list of steps involved in the operations, along with their
 * progress information. Currently the progress field can take 4 integer values
 * (-1: error, 0: waiting, 1: moving, 2: moved). The progress field is of type bigint
 * because we may implement a more granular, byte-level progress as a future improvement.
 */
Datum get_rebalance_progress(PG_FUNCTION_ARGS)
{
    CheckCitusVersion(ERROR);
    List* segmentList = NIL;
    TupleDesc tupdesc;
    Tuplestorestate* tupstore = SetupTuplestore(fcinfo, &tupdesc);

    /* get the addresses of all current rebalance monitors */
    List* rebalanceMonitorList =
        ProgressMonitorList(REBALANCE_ACTIVITY_MAGIC_NUMBER, &segmentList);

    ProgressMonitorData* monitor = NULL;
    foreach_declared_ptr(monitor, rebalanceMonitorList)
    {
        /* per-monitor: fetch per-worker shard statistics over fresh connections */
        PlacementUpdateEventProgress* placementUpdateEvents =
            static_cast<PlacementUpdateEventProgress*>(ProgressMonitorSteps(monitor));
        HTAB* shardStatistics =
            BuildWorkerShardStatisticsHash(placementUpdateEvents, monitor->stepCount);
        HTAB* shardSizes = BuildShardSizesHash(monitor, shardStatistics);
        for (int eventIndex = 0; eventIndex < monitor->stepCount; eventIndex++) {
            PlacementUpdateEventProgress* step = placementUpdateEvents + eventIndex;
            uint64 shardId = step->shardId;
            ShardInterval* shardInterval = LoadShardInterval(shardId);

            /* per-node sizes of this shard on the source and target workers */
            uint64 sourceSize = WorkerShardSize(shardStatistics, step->sourceName,
                                                step->sourcePort, shardId);
            uint64 targetSize = WorkerShardSize(shardStatistics, step->targetName,
                                                step->targetPort, shardId);

            /*
             * NOTE(review): the source LSN is node-level (WorkerLSN) while
             * the target LSN is shard-level (WorkerShardLSN, per
             * subscription) — confirm this asymmetry is intentional.
             */
            XLogRecPtr sourceLSN =
                WorkerLSN(shardStatistics, step->sourceName, step->sourcePort);
            XLogRecPtr targetLSN = WorkerShardLSN(shardStatistics, step->targetName,
                                                  step->targetPort, shardId);

            /* overall shard size picked by BuildShardSizesHash; 0 if unknown */
            uint64 shardSize = 0;
            ShardStatistics* shardSizesStat = static_cast<ShardStatistics*>(
                hash_search(shardSizes, &shardId, HASH_FIND, NULL));
            if (shardSizesStat) {
                shardSize = shardSizesStat->totalSize;
            }

            Datum values[15];
            bool nulls[15];

            memset(values, 0, sizeof(values));
            memset(nulls, 0, sizeof(nulls));

            values[0] = monitor->processId;
            values[1] = ObjectIdGetDatum(shardInterval->relationId);
            values[2] = UInt64GetDatum(shardId);
            values[3] = UInt64GetDatum(shardSize);
            values[4] = PointerGetDatum(cstring_to_text(step->sourceName));
            values[5] = UInt32GetDatum(step->sourcePort);
            values[6] = PointerGetDatum(cstring_to_text(step->targetName));
            values[7] = UInt32GetDatum(step->targetPort);
            /* progress/updateStatus are atomics written by the rebalancing backend */
            values[8] = UInt64GetDatum(pg_atomic_read_u64(&step->progress));
            values[9] = UInt64GetDatum(sourceSize);
            values[10] = UInt64GetDatum(targetSize);
            values[11] = PointerGetDatum(
                cstring_to_text(PlacementUpdateTypeNames[step->updateType]));
            /* LSN columns are NULL when no LSN data was available */
            values[12] = LSNGetDatum(sourceLSN);
            if (sourceLSN == InvalidXLogRecPtr) {
                nulls[12] = true;
            }

            values[13] = LSNGetDatum(targetLSN);
            if (targetLSN == InvalidXLogRecPtr) {
                nulls[13] = true;
            }

            values[14] = PointerGetDatum(cstring_to_text(
                PlacementUpdateStatusNames[pg_atomic_read_u64(&step->updateStatus)]));

            tuplestore_putvalues(tupstore, tupdesc, values, nulls);
        }
    }

    /* unmap the shared-memory segments we attached to above */
    DetachFromDSMSegments(segmentList);

    return (Datum)0;
}

/*
 * BuildShardSizesHash creates a hash that maps a shardid to its full size
 * within the cluster. It does this by using the rebalance progress monitor
 * state to find the node the shard is currently on. It then looks up the shard
 * size in the shardStatistics hashmap for this node.
 */
static HTAB* BuildShardSizesHash(ProgressMonitorData* monitor, HTAB* shardStatistics)
{
    /* shardId (uint64) -> ShardStatistics; only totalSize is populated here */
    HASHCTL info = {.keysize = sizeof(uint64),
                    .entrysize = sizeof(ShardStatistics),
                    .hcxt = CurrentMemoryContext};

    HTAB* shardSizes =
        hash_create("ShardSizeHash", 32, &info, HASH_ELEM | HASH_CONTEXT | HASH_BLOBS);
    PlacementUpdateEventProgress* placementUpdateEvents =
        static_cast<PlacementUpdateEventProgress*>(ProgressMonitorSteps(monitor));

    for (int eventIndex = 0; eventIndex < monitor->stepCount; eventIndex++) {
        PlacementUpdateEventProgress* step = placementUpdateEvents + eventIndex;

        uint64 shardId = step->shardId;
        uint64 shardSize = 0;
        uint64 backupShardSize = 0;
        /* atomic read: the rebalancing backend updates this concurrently */
        uint64 progress = pg_atomic_read_u64(&step->progress);

        uint64 sourceSize =
            WorkerShardSize(shardStatistics, step->sourceName, step->sourcePort, shardId);
        uint64 targetSize =
            WorkerShardSize(shardStatistics, step->targetName, step->targetPort, shardId);

        if (progress == REBALANCE_PROGRESS_WAITING ||
            progress == REBALANCE_PROGRESS_MOVING) {
            /*
             * If we are not done with the move, the correct shard size is the
             * size on the source.
             */
            shardSize = sourceSize;
            backupShardSize = targetSize;
        } else if (progress == REBALANCE_PROGRESS_MOVED) {
            /*
             * If we are done with the move, the correct shard size is the size
             * on the target
             */
            shardSize = targetSize;
            backupShardSize = sourceSize;
        }
        /* (other progress values, e.g. error, leave both sizes at 0) */

        if (shardSize == 0) {
            if (backupShardSize == 0) {
                /*
                 * We don't have any useful shard size. This can happen when a
                 * shard is moved multiple times and it is not present on
                 * either of these nodes. Probably the shard is on a worker
                 * related to another event. In the weird case that this shard
                 * is on the nodes and actually is size 0, we will have no
                 * entry in the hashmap. When fetching from it we always
                 * default to 0 if no entry is found, so that's fine.
                 */
                continue;
            }

            /*
             * Because of the way we fetch shard sizes they are from a slightly
             * earlier moment than the progress state we just read from shared
             * memory. Usually this is no problem, but there exist some race
             * conditions where this matters. For example, for very quick moves
             * it is possible that even though a step is now reported as MOVED,
             * when we read the shard sizes the move had not even started yet.
             * This in turn can mean that the target size is 0 while the source
             * size is not. We try to handle such rare edge cases by falling
             * back on the other shard size if that one is not 0.
             */
            shardSize = backupShardSize;
        }

        /* HASH_ENTER: later steps for the same shard overwrite earlier ones */
        ShardStatistics* currentWorkerStatistics = static_cast<ShardStatistics*>(
            hash_search(shardSizes, &shardId, HASH_ENTER, NULL));
        currentWorkerStatistics->totalSize = shardSize;
    }
    return shardSizes;
}

/*
 * WorkerShardSize returns the size of a shard in bytes on a worker, based on
 * the workerShardStatisticsHash.
 */
static uint64 WorkerShardSize(HTAB* workerShardStatisticsHash, char* workerName,
                              int workerPort, uint64 shardId)
{
    /* build a zero-padded key so HASH_BLOBS comparison is well-defined */
    WorkerHashKey lookupKey = {0};
    lookupKey.port = workerPort;
    strlcpy(lookupKey.hostname, workerName, MAX_NODE_LENGTH);

    WorkerShardStatistics* workerEntry = static_cast<WorkerShardStatistics*>(
        hash_search(workerShardStatisticsHash, &lookupKey, HASH_FIND, NULL));
    if (workerEntry == NULL) {
        return 0;
    }

    ShardStatistics* shardEntry = static_cast<ShardStatistics*>(
        hash_search(workerEntry->statistics, &shardId, HASH_FIND, NULL));

    /* unknown shard on this worker counts as size 0 */
    return (shardEntry == NULL) ? 0 : shardEntry->totalSize;
}

/*
 * WorkerShardLSN returns the LSN of a shard on a worker, based on
 * the workerShardStatisticsHash. If there is no LSN data in the
 * statistics object, returns InvalidXLogRecPtr.
 */
static XLogRecPtr WorkerShardLSN(HTAB* workerShardStatisticsHash, char* workerName,
                                 int workerPort, uint64 shardId)
{
    /* build a zero-padded key so HASH_BLOBS comparison is well-defined */
    WorkerHashKey lookupKey = {0};
    lookupKey.port = workerPort;
    strlcpy(lookupKey.hostname, workerName, MAX_NODE_LENGTH);

    WorkerShardStatistics* workerEntry = static_cast<WorkerShardStatistics*>(
        hash_search(workerShardStatisticsHash, &lookupKey, HASH_FIND, NULL));
    if (workerEntry == NULL) {
        return InvalidXLogRecPtr;
    }

    ShardStatistics* shardEntry = static_cast<ShardStatistics*>(
        hash_search(workerEntry->statistics, &shardId, HASH_FIND, NULL));

    /* missing shard entry means no LSN data for this shard */
    return (shardEntry == NULL) ? InvalidXLogRecPtr : shardEntry->shardLSN;
}

/*
 * WorkerLSN returns the LSN of a worker, based on the workerShardStatisticsHash.
 * If there is no LSN data in the statistics object, returns InvalidXLogRecPtr.
 */
static XLogRecPtr WorkerLSN(HTAB* workerShardStatisticsHash, char* workerName,
                            int workerPort)
{
    /* build a zero-padded key so HASH_BLOBS comparison is well-defined */
    WorkerHashKey lookupKey = {0};
    lookupKey.port = workerPort;
    strlcpy(lookupKey.hostname, workerName, MAX_NODE_LENGTH);

    WorkerShardStatistics* workerEntry = static_cast<WorkerShardStatistics*>(
        hash_search(workerShardStatisticsHash, &lookupKey, HASH_FIND, NULL));

    /* missing worker entry means no LSN data for this node */
    return (workerEntry == NULL) ? InvalidXLogRecPtr : workerEntry->workerLSN;
}

/*
 * BuildWorkerShardStatisticsHash returns a shard id -> shard statistics hash containing
 * sizes of shards on the source node and destination node.
 */
static HTAB* BuildWorkerShardStatisticsHash(PlacementUpdateEventProgress* steps,
                                            int stepCount)
{
    /* group the moved shard ids by the worker they live on */
    HTAB* shardsByWorker = GetMovedShardIdsByWorker(steps, stepCount, true);

    /* WorkerHashKey -> WorkerShardStatistics (per-shard sizes + node LSN) */
    HASHCTL info = {.keysize = sizeof(WorkerHashKey),
                    .entrysize = sizeof(WorkerShardStatistics),
                    .hcxt = CurrentMemoryContext};

    HTAB* workerShardStatistics = hash_create("WorkerShardStatistics", 32, &info,
                                              HASH_ELEM | HASH_CONTEXT | HASH_BLOBS);
    WorkerShardIds* entry = NULL;

    HASH_SEQ_STATUS status;
    hash_seq_init(&status, shardsByWorker);
    while ((entry = static_cast<WorkerShardIds*>(hash_seq_search(&status))) != NULL) {
        int connectionFlags = 0;
        MultiConnection* connection = GetNodeConnection(
            connectionFlags, entry->worker.hostname, entry->worker.port);

        HTAB* statistics = GetShardStatistics(connection, entry->shardIds);

        /*
         * entry->worker is usable directly as a HASH_BLOBS key: it was
         * created via AddToWorkerShardIdSet from a zero-initialized
         * WorkerHashKey, so its padding bytes are zero. The previous code
         * also built a local strlcpy'd copy of the key here but never used
         * it; that dead code is removed.
         */
        WorkerShardStatistics* moveStat = static_cast<WorkerShardStatistics*>(
            hash_search(workerShardStatistics, &entry->worker, HASH_ENTER, NULL));
        moveStat->statistics = statistics;
        moveStat->workerLSN = GetRemoteLSN(connection);
    }

    return workerShardStatistics;
}

/*
 * GetShardStatistics fetches the statistics for the given shard ids over the
 * given connection. It returns a hashmap where the keys are the shard ids and
 * the values are the statistics.
 */
static HTAB* GetShardStatistics(MultiConnection* connection, HTAB* shardIds)
{
    StringInfo query = makeStringInfo();

    /* build a VALUES list of (shard_id, schema_name, table_name) tuples */
    appendStringInfoString(
        query, "WITH shard_names (shard_id, schema_name, table_name) AS ((VALUES ");

    bool isFirst = true;
    uint64* shardIdPtr = NULL;
    HASH_SEQ_STATUS status;
    hash_seq_init(&status, shardIds);
    while ((shardIdPtr = static_cast<uint64*>(hash_seq_search(&status))) != NULL) {
        uint64 shardId = *shardIdPtr;
        ShardInterval* shardInterval = LoadShardInterval(shardId);
        Oid relationId = shardInterval->relationId;
        char* shardName = get_rel_name(relationId);

        AppendShardIdToName(&shardName, shardId);

        Oid schemaId = get_rel_namespace(relationId);
        char* schemaName = get_namespace_name(schemaId);
        if (!isFirst) {
            appendStringInfo(query, ", ");
        }

        appendStringInfo(query, "(" UINT64_FORMAT ",%s,%s)", shardId,
                         quote_literal_cstr(schemaName), quote_literal_cstr(shardName));

        isFirst = false;
    }

    appendStringInfoString(query, "))");
    appendStringInfoString(
        query,
        " SELECT shard_id, coalesce(pg_total_relation_size(tables.relid),0), tables.lsn"

        /* for each shard in shardIds */
        " FROM shard_names"

        /* check if its name can be found in pg_class, if so return size */
        " LEFT JOIN"
        " (SELECT c.oid AS relid, c.relname, n.nspname, ss.latest_end_lsn AS lsn"
        " FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace "
        " LEFT JOIN pg_subscription_rel sr ON sr.srrelid = c.oid "
        " LEFT JOIN pg_stat_subscription ss ON sr.srsubid = ss.subid) tables"
        " ON tables.relname = shard_names.table_name AND"
        " tables.nspname = shard_names.schema_name ");

    PGresult* result = NULL;
    int queryResult = ExecuteOptionalRemoteCommand(connection, query->data, &result);
    if (queryResult != RESPONSE_OKAY) {
        ereport(ERROR, (errcode(ERRCODE_CONNECTION_FAILURE),
                        errmsg("cannot get the size because of a connection error")));
    }

    int rowCount = PQntuples(result);
    int colCount = PQnfields(result);

    /*
     * This is not expected to ever happen, but we check just to be sure.
     * The query above selects three columns (shard_id, size, lsn) and
     * column index 2 is read below, so anything fewer than 3 columns
     * would be an out-of-range access. The previous check (< 2) was too
     * lax for the LSN column.
     */
    if (colCount < 3) {
        ereport(ERROR,
                (errmsg("unexpected number of columns returned by: %s", query->data)));
    }

    /* shardId (uint64) -> ShardStatistics */
    HASHCTL info = {.keysize = sizeof(uint64),
                    .entrysize = sizeof(ShardStatistics),
                    .hcxt = CurrentMemoryContext};

    HTAB* shardStatistics = hash_create("ShardStatisticsHash", 32, &info,
                                        HASH_ELEM | HASH_CONTEXT | HASH_BLOBS);

    for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
        char* shardIdString = PQgetvalue(result, rowIndex, 0);
        uint64 shardId = std::strtoull(shardIdString, NULL, 10);
        char* sizeString = PQgetvalue(result, rowIndex, 1);
        uint64 totalSize = std::strtoull(sizeString, NULL, 10);

        ShardStatistics* statistics = static_cast<ShardStatistics*>(
            hash_search(shardStatistics, &shardId, HASH_ENTER, NULL));
        statistics->totalSize = totalSize;

        /* the lsn column is NULL when the shard has no active subscription */
        if (PQgetisnull(result, rowIndex, 2)) {
            statistics->shardLSN = InvalidXLogRecPtr;
        } else {
            char* LSNString = PQgetvalue(result, rowIndex, 2);
            Datum LSNDatum = DirectFunctionCall1(pg_lsn_in, CStringGetDatum(LSNString));
            statistics->shardLSN = DatumGetLSN(LSNDatum);
        }
    }

    PQclear(result);

    bool raiseErrors = true;
    ClearResults(connection, raiseErrors);

    return shardStatistics;
}

/*
 * GetMovedShardIdsByWorker groups the shard ids in the provided steps by
 * worker. It returns a hashmap that contains a set of these shard ids.
 */
static HTAB* GetMovedShardIdsByWorker(PlacementUpdateEventProgress* steps, int stepCount,
                                      bool fromSource)
{
    /* NOTE(review): fromSource is currently unused — confirm whether callers rely on it */
    HASHCTL hashInfo = {.keysize = sizeof(WorkerHashKey),
                        .entrysize = sizeof(WorkerShardIds),
                        .hcxt = CurrentMemoryContext};

    HTAB* workerShardSets = hash_create("GetRebalanceStepsByWorker", 32, &hashInfo,
                                        HASH_ELEM | HASH_CONTEXT | HASH_BLOBS);

    for (int i = 0; i < stepCount; i++) {
        PlacementUpdateEventProgress* step = &steps[i];

        /* the source worker is always involved in the move */
        AddToWorkerShardIdSet(workerShardSets, step->sourceName, step->sourcePort,
                              step->shardId);

        /*
         * Only track the target worker once the move has actually started;
         * while the step is still waiting the target holds nothing for this
         * shard.
         */
        if (pg_atomic_read_u64(&step->progress) != REBALANCE_PROGRESS_WAITING) {
            AddToWorkerShardIdSet(workerShardSets, step->targetName, step->targetPort,
                                  step->shardId);
        }
    }

    return workerShardSets;
}

/*
 * AddToWorkerShardIdSet adds the shard id to the shard id set for the
 * specified worker in the shardsByWorker hashmap.
 */
static void AddToWorkerShardIdSet(HTAB* shardsByWorker, char* workerName, int workerPort,
                                  uint64 shardId)
{
    /* build the (hostname, port) lookup key for this worker */
    WorkerHashKey key = {0};
    strlcpy(key.hostname, workerName, MAX_NODE_LENGTH);
    key.port = workerPort;

    bool workerSeen = false;
    WorkerShardIds* entry = static_cast<WorkerShardIds*>(
        hash_search(shardsByWorker, &key, HASH_ENTER, &workerSeen));

    if (!workerSeen) {
        /* first shard for this worker: lazily create its shard id set */
        HASHCTL setInfo = {.keysize = sizeof(uint64),
                           .entrysize = sizeof(uint64),
                           .hcxt = CurrentMemoryContext};

        entry->shardIds = hash_create("WorkerShardIdsSet", 32, &setInfo,
                                      HASH_ELEM | HASH_CONTEXT | HASH_BLOBS);
    }

    hash_search(entry->shardIds, &shardId, HASH_ENTER, NULL);
}

/*
 * NonColocatedDistRelationIdList returns a list of distributed table oids, one
 * for each existing colocation group.
 */
static List* NonColocatedDistRelationIdList(void)
{
    List* relationIdList = NIL;
    List* allCitusTablesList = CitusTableTypeIdList(ANY_CITUS_TABLE_TYPE);
    Oid tableId = InvalidOid;

    /* allocate sufficient capacity for O(1) expected look-up time */
    int capacity = (int)(list_length(allCitusTablesList) / 0.75) + 1;
    int flags = HASH_ELEM | HASH_CONTEXT | HASH_BLOBS;
    HASHCTL info = {
        .keysize = sizeof(Oid), .entrysize = sizeof(Oid), .hcxt = CurrentMemoryContext};

    /* set of colocation ids we already picked a representative table for */
    HTAB* alreadySelectedColocationIds =
        hash_create("RebalanceColocationIdSet", capacity, &info, flags);
    foreach_declared_oid(tableId, allCitusTablesList)
    {
        bool foundInSet = false;
        CitusTableCacheEntry* citusTableCacheEntry = GetCitusTableCacheEntry(tableId);

        if (!IsCitusTableTypeCacheEntry(citusTableCacheEntry, DISTRIBUTED_TABLE)) {
            /*
             * We're only interested in distributed tables, should ignore
             * reference tables and citus local tables.
             */
            continue;
        }

        if (citusTableCacheEntry->colocationId != INVALID_COLOCATION_ID) {
            /* HASH_ENTER reports whether this colocation group was seen before */
            hash_search(alreadySelectedColocationIds, &citusTableCacheEntry->colocationId,
                        HASH_ENTER, &foundInSet);
            if (foundInSet) {
                continue;
            }
        }
        relationIdList = lappend_oid(relationIdList, tableId);
    }

    /*
     * The hash set was only needed for deduplication during the loop above;
     * destroy it eagerly instead of leaving it in CurrentMemoryContext.
     */
    hash_destroy(alreadySelectedColocationIds);

    return relationIdList;
}

/*
 * RebalanceTableShards rebalances the shards for the relations inside the
 * relationIdList across the different workers.
 */
static void RebalanceTableShards(RebalanceOptions* options, Oid shardReplicationModeOid)
{
    /* shards are always transferred while blocking writes in this code path */
    char transferMode = TRANSFER_MODE_BLOCK_WRITES;

    if (list_length(options->relationIdList) == 0) {
        /* nothing to rebalance, but still make sure reference tables exist everywhere */
        EnsureReferenceTablesExistOnAllNodesExtended(transferMode);
        return;
    }

    char* operationName = "rebalance";
    if (options->drainOnly) {
        operationName = "move";
    }

    options->operationName = operationName;
    /* errors out if another rebalance holds the colocation locks or a job is running */
    ErrorOnConcurrentRebalance(options);

    List* placementUpdateList = GetRebalanceSteps(options);

    /*
     * NOTE(review): transferMode is fixed to TRANSFER_MODE_BLOCK_WRITES above,
     * so this branch is currently unreachable; it looks kept for when other
     * transfer modes are (re)enabled — confirm before removing.
     */
    if (transferMode == TRANSFER_MODE_AUTOMATIC) {
        /*
         * If the shard transfer mode is set to auto, we should check beforehand
         * if we are able to use logical replication to transfer shards or not.
         * We throw an error if any of the tables do not have a replica identity, which
         * is required for logical replication to replicate UPDATE and DELETE commands.
         */
        PlacementUpdateEvent* placementUpdate = NULL;
        foreach_declared_ptr(placementUpdate, placementUpdateList)
        {
            Oid relationId = RelationIdForShard(placementUpdate->shardId);
            List* colocatedTableList = ColocatedTableList(relationId);
            VerifyTablesHaveReplicaIdentity(colocatedTableList);
        }
    }

    EnsureReferenceTablesExistOnAllNodesExtended(transferMode);

    if (list_length(placementUpdateList) == 0) {
        /* the cluster is already balanced; nothing to execute */
        return;
    }

    /*
     * This uses the first relationId from the list, it's only used for display
     * purposes so it does not really matter which to show
     */
    SetupRebalanceMonitor(placementUpdateList, linitial_oid(options->relationIdList),
                          REBALANCE_PROGRESS_WAITING,
                          PLACEMENT_UPDATE_STATUS_NOT_STARTED_YET);
    ExecutePlacementUpdates(placementUpdateList, shardReplicationModeOid, "Moving");
    FinalizeCurrentProgressMonitor();
}

/*
 * ErrorOnConcurrentRebalance raises an error with extra information when there is already
 * a rebalance running.
 */
static void ErrorOnConcurrentRebalance(RebalanceOptions* options)
{
    Oid relationId = InvalidOid;
    foreach_declared_oid(relationId, options->relationIdList)
    {
        /* this provides the legacy error when the lock can't be acquired */
        AcquireRebalanceColocationLock(relationId, options->operationName);
    }

    int64 jobId = 0;
    if (HasNonTerminalJobOfType("rebalance", &jobId)) {
        /*
         * Use INT64_FORMAT rather than %ld: jobId is int64, which is not
         * 'long' on every platform (e.g. LLP64).
         */
        ereport(ERROR,
                (errmsg("A rebalance is already running as job " INT64_FORMAT,
                        jobId),
                 errdetail("A rebalance was already scheduled as background job")));
    }
}

/*
 * GetColocationId function returns the colocationId of the shard in a
 * PlacementUpdateEvent.
 */
static int64 GetColocationId(PlacementUpdateEvent* move)
{
    /* resolve the shard to its distributed table, then read the cached colocation id */
    ShardInterval* interval = LoadShardInterval(move->shardId);
    CitusTableCacheEntry* cacheEntry = GetCitusTableCacheEntry(interval->relationId);

    return cacheEntry->colocationId;
}

/*
 * InitializeShardMoveDependencies function creates the hash maps that we use to track
 * the latest moves so that subsequent moves with the same properties must take a
 * dependency on them. There are two hash maps. One is for tracking the latest move
 * scheduled in a given colocation group and the other one is for tracking source nodes of
 * all moves.
 */
static ShardMoveDependencies InitializeShardMoveDependencies()
{
    ShardMoveDependencies dependencies;

    /* latest task scheduled per colocation group, keyed by colocation id */
    dependencies.colocationDependencies = CreateSimpleHashWithNameAndSize(
        int64, ShardMoveDependencyInfo, "colocationDependencyHashMap", 6);

    /* tasks that move shards away from each node, keyed by source node id */
    dependencies.nodeDependencies = CreateSimpleHashWithNameAndSize(
        int32, ShardMoveSourceNodeHashEntry, "nodeDependencyHashMap", 6);

    return dependencies;
}

/*
 * GenerateTaskMoveDependencyList creates and returns an array of task ids that
 * the move must take a dependency on, given the shard move dependencies as
 * input. The number of entries is returned through nDepends; the result is
 * NULL when there are no dependencies.
 */
static int64* GenerateTaskMoveDependencyList(PlacementUpdateEvent* move,
                                             int64 colocationId,
                                             ShardMoveDependencies shardMoveDependencies,
                                             int* nDepends)
{
    /* de-duplicating set of task ids this move must wait for */
    HTAB* dependsList =
        CreateSimpleHashSetWithNameAndSize(int64, "shardMoveDependencyList", 0);

    bool found;

    /*
     * Check if there exists a move in the same colocation group scheduled
     * earlier. HASH_ENTER intentionally pre-creates the entry when missing;
     * its taskId is filled in later by UpdateShardMoveDependencies, which the
     * caller invokes after scheduling each move.
     */
    ShardMoveDependencyInfo* shardMoveDependencyInfo =
        static_cast<ShardMoveDependencyInfo*>(
            hash_search(shardMoveDependencies.colocationDependencies, &colocationId,
                        HASH_ENTER, &found));

    if (found) {
        hash_search(dependsList, &shardMoveDependencyInfo->taskId, HASH_ENTER, NULL);
    }

    /*
     * Check if there exists moves scheduled earlier whose source node
     * overlaps with the current move's target node.
     * The earlier/first move might make space for the later/second move.
     * So we could run out of disk space (or at least overload the node)
     * if we move the second shard to it before the first one is moved away.
     */
    ShardMoveSourceNodeHashEntry* shardMoveSourceNodeHashEntry =
        static_cast<ShardMoveSourceNodeHashEntry*>(
            hash_search(shardMoveDependencies.nodeDependencies, &move->targetNode->nodeId,
                        HASH_FIND, &found));

    if (found) {
        int64* taskId = NULL;
        foreach_declared_ptr(taskId, shardMoveSourceNodeHashEntry->taskIds)
        {
            hash_search(dependsList, taskId, HASH_ENTER, NULL);
        }
    }

    *nDepends = hash_get_num_entries(dependsList);

    int64* dependsArray = NULL;

    if (*nDepends > 0) {
        HASH_SEQ_STATUS seq;

        dependsArray = static_cast<int64*>(palloc((*nDepends) * sizeof(int64)));

        /* flatten the set into the array the background task API expects */
        hash_seq_init(&seq, dependsList);
        int i = 0;
        int64* dependsTaskId;

        while ((dependsTaskId = (int64*)hash_seq_search(&seq)) != NULL) {
            dependsArray[i++] = *dependsTaskId;
        }
    }

    /*
     * The set has served its deduplication purpose; destroy it instead of
     * leaking one hash table per scheduled move into CurrentMemoryContext.
     */
    hash_destroy(dependsList);

    return dependsArray;
}

/*
 * UpdateShardMoveDependencies function updates the dependency maps with the latest move's
 * taskId.
 */
static void UpdateShardMoveDependencies(PlacementUpdateEvent* move, uint64 colocationId,
                                        int64 taskId,
                                        ShardMoveDependencies shardMoveDependencies)
{
    /* record this task as the latest one scheduled in its colocation group */
    ShardMoveDependencyInfo* colocationEntry = static_cast<ShardMoveDependencyInfo*>(
        hash_search(shardMoveDependencies.colocationDependencies, &colocationId,
                    HASH_ENTER, NULL));
    colocationEntry->taskId = taskId;

    /* also register the task under the node it moves data away from */
    bool nodeSeen = false;
    ShardMoveSourceNodeHashEntry* nodeEntry =
        static_cast<ShardMoveSourceNodeHashEntry*>(
            hash_search(shardMoveDependencies.nodeDependencies, &move->sourceNode->nodeId,
                        HASH_ENTER, &nodeSeen));

    if (!nodeSeen) {
        /* a fresh entry starts with an empty task list */
        nodeEntry->taskIds = NIL;
    }

    int64* taskIdCopy = static_cast<int64*>(palloc0(sizeof(int64)));
    *taskIdCopy = taskId;
    nodeEntry->taskIds = lappend(nodeEntry->taskIds, taskIdCopy);
}

/*
 * RebalanceTableShardsBackground rebalances the shards for the relations
 * inside the relationIdList across the different workers. It does so using our
 * background job+task infrastructure.
 */
static int64 RebalanceTableShardsBackground(RebalanceOptions* options,
                                            Oid shardReplicationModeOid)
{
    if (list_length(options->relationIdList) == 0) {
        ereport(NOTICE, (errmsg("No tables to rebalance")));
        return 0;
    }

    char* operationName = "rebalance";
    if (options->drainOnly) {
        operationName = "move";
    }

    options->operationName = operationName;
    /* fail fast if another rebalance (foreground or background) is in flight */
    ErrorOnConcurrentRebalance(options);

    /* shards are always transferred while blocking writes in this code path */
    const char shardTransferMode = TRANSFER_MODE_BLOCK_WRITES;
    List* colocatedTableList = NIL;
    Oid relationId = InvalidOid;
    foreach_declared_oid(relationId, options->relationIdList)
    {
        colocatedTableList =
            list_concat(colocatedTableList, ColocatedTableList(relationId));
    }
    Oid colocatedTableId = InvalidOid;
    foreach_declared_oid(colocatedTableId, colocatedTableList)
    {
        /* moving a shard moves its whole colocation group; require ownership of all */
        EnsureTableOwner(colocatedTableId);
    }

    List* placementUpdateList = GetRebalanceSteps(options);

    if (list_length(placementUpdateList) == 0) {
        ereport(NOTICE, (errmsg("No moves available for rebalancing")));
        return 0;
    }

    /*
     * NOTE(review): shardTransferMode is a constant set to
     * TRANSFER_MODE_BLOCK_WRITES above, so the TRANSFER_MODE_AUTOMATIC
     * branches below are currently unreachable — confirm whether they are
     * kept intentionally for future transfer modes.
     */
    if (shardTransferMode == TRANSFER_MODE_AUTOMATIC) {
        /*
         * If the shard transfer mode is set to auto, we should check beforehand
         * if we are able to use logical replication to transfer shards or not.
         * We throw an error if any of the tables do not have a replica identity, which
         * is required for logical replication to replicate UPDATE and DELETE commands.
         */
        PlacementUpdateEvent* placementUpdate = NULL;
        foreach_declared_ptr(placementUpdate, placementUpdateList)
        {
            relationId = RelationIdForShard(placementUpdate->shardId);
            List* colocatedTables = ColocatedTableList(relationId);
            VerifyTablesHaveReplicaIdentity(colocatedTables);
        }
    }

    /* clean up leftovers of previous transfers before scheduling new ones */
    DropOrphanedResourcesInSeparateTransaction();

    /* find the name of the shard transfer mode to interpolate in the scheduled command */
    Datum shardTranferModeLabelDatum =
        DirectFunctionCall1(enum_out, shardReplicationModeOid);
    char* shardTranferModeLabel = DatumGetCString(shardTranferModeLabelDatum);

    /* schedule planned moves */
    int64 jobId = CreateBackgroundJob("rebalance", "Rebalance all colocation groups");

    /* buffer used to construct the sql command for the tasks */
    StringInfoData buf = {0};
    initStringInfo(&buf);

    List* referenceTableIdList = NIL;
    int64 replicateRefTablesTaskId = 0;

    if (HasNodesWithMissingReferenceTables(&referenceTableIdList)) {
        if (shardTransferMode == TRANSFER_MODE_AUTOMATIC) {
            VerifyTablesHaveReplicaIdentity(referenceTableIdList);
        }

        /*
         * Reference tables need to be copied to (newly-added) nodes, this needs to be the
         * first task before we can move any other table.
         */
        appendStringInfo(&buf, "SELECT pg_catalog.replicate_reference_tables(%s)",
                         quote_literal_cstr(shardTranferModeLabel));

        int32 nodesInvolved[] = {0};

        /* replicate_reference_tables permissions require superuser */
        Oid superUserId = CitusExtensionOwner();
        BackgroundTask* task = ScheduleBackgroundTask(jobId, superUserId, buf.data, 0,
                                                      NULL, 0, nodesInvolved);
        replicateRefTablesTaskId = task->taskid;
    }

    PlacementUpdateEvent* move = NULL;

    /* hash maps that let later conflicting moves depend on earlier tasks */
    ShardMoveDependencies shardMoveDependencies = InitializeShardMoveDependencies();

    foreach_declared_ptr(move, placementUpdateList)
    {
        resetStringInfo(&buf);

        appendStringInfo(&buf, "SELECT pg_catalog.spq_move_shard_placement(%ld,%u,%u,%s)",
                         move->shardId, move->sourceNode->nodeId,
                         move->targetNode->nodeId,
                         quote_literal_cstr(shardTranferModeLabel));

        int64 colocationId = GetColocationId(move);

        int nDepends = 0;

        /* tasks this move must wait for: same colocation group, or target-node reuse */
        int64* dependsArray = GenerateTaskMoveDependencyList(
            move, colocationId, shardMoveDependencies, &nDepends);

        if (nDepends == 0 && replicateRefTablesTaskId > 0) {
            /* every otherwise-independent move still waits for reference table copy */
            nDepends = 1;
            dependsArray = static_cast<int64*>(palloc(nDepends * sizeof(int64)));
            dependsArray[0] = replicateRefTablesTaskId;
        }

        int32 nodesInvolved[2] = {0};
        nodesInvolved[0] = move->sourceNode->nodeId;
        nodesInvolved[1] = move->targetNode->nodeId;

        BackgroundTask* task = ScheduleBackgroundTask(
            jobId, GetUserId(), buf.data, nDepends, dependsArray, 2, nodesInvolved);

        /* record this task so later conflicting moves take a dependency on it */
        UpdateShardMoveDependencies(move, colocationId, task->taskid,
                                    shardMoveDependencies);
    }

    /* NOTE(review): %ld assumes int64 == long (LP64); consider INT64_FORMAT */
    ereport(NOTICE, (errmsg("Scheduled %d moves as job %ld",
                            list_length(placementUpdateList), jobId),
                     errdetail("Rebalance scheduled as background job")));

    return jobId;
}

/*
 * UpdateShardPlacement copies or moves a shard placement by calling
 * the corresponding functions in Citus in a subtransaction.
 */
static void UpdateShardPlacement(PlacementUpdateEvent* placementUpdateEvent,
                                 List* responsiveNodeList, Oid shardReplicationModeOid)
{
    PlacementUpdateType updateType = placementUpdateEvent->updateType;
    uint64 shardId = placementUpdateEvent->shardId;
    WorkerNode* sourceNode = placementUpdateEvent->sourceNode;
    WorkerNode* targetNode = placementUpdateEvent->targetNode;

    /* resolve the human-readable label of the transfer mode enum value */
    Datum transferModeDatum = DirectFunctionCall1(enum_out, shardReplicationModeOid);
    char* transferModeLabel = DatumGetCString(transferModeDatum);

    StringInfo transferCommand = makeStringInfo();

    /* refuse to start a transfer when the target endpoint is unresponsive */
    if (!WorkerNodeListContains(responsiveNodeList, targetNode->workerName,
                                targetNode->workerPort)) {
        ereport(ERROR, (errmsg("target node %s:%d is not responsive",
                               targetNode->workerName, targetNode->workerPort)));
    }

    /* same check for the source endpoint */
    if (!WorkerNodeListContains(responsiveNodeList, sourceNode->workerName,
                                sourceNode->workerPort)) {
        ereport(ERROR, (errmsg("source node %s:%d is not responsive",
                               sourceNode->workerName, sourceNode->workerPort)));
    }

    if (updateType == PLACEMENT_UPDATE_MOVE) {
        appendStringInfo(transferCommand,
                         "SELECT pg_catalog.spq_move_shard_placement(%ld,%u,%u,%s)",
                         shardId, sourceNode->nodeId, targetNode->nodeId,
                         quote_literal_cstr(transferModeLabel));
    } else if (updateType == PLACEMENT_UPDATE_COPY) {
        appendStringInfo(transferCommand,
                         "SELECT pg_catalog.citus_copy_shard_placement(%ld,%u,%u,%s)",
                         shardId, sourceNode->nodeId, targetNode->nodeId,
                         quote_literal_cstr(transferModeLabel));
    } else {
        ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                        errmsg("only moving or copying shards is supported")));
    }

    /* flag the colocated placements as in-flight before kicking off the work */
    UpdateColocatedShardPlacementProgress(shardId, sourceNode->workerName,
                                          sourceNode->workerPort,
                                          REBALANCE_PROGRESS_MOVING);

    /*
     * In case of failure, we throw an error such that rebalance_table_shards
     * fails early.
     */
    ExecuteRebalancerCommandInSeparateTransaction(transferCommand->data);

    UpdateColocatedShardPlacementProgress(shardId, sourceNode->workerName,
                                          sourceNode->workerPort,
                                          REBALANCE_PROGRESS_MOVED);
}

/*
 * ExecuteRebalancerCommandInSeparateTransaction runs a command in a separate
 * transaction that is committed right away. This is useful for things that you
 * don't want to roll back when the current transaction is rolled back.
 */
void ExecuteRebalancerCommandInSeparateTransaction(char* command)
{
    /* force a fresh loopback connection so the command commits independently */
    int connectionFlag = FORCE_NEW_CONNECTION;
    MultiConnection* connection =
        GetNodeConnection(connectionFlag, Session_ctx::Vars().LocalHostName,
                          g_instance.attr.attr_network.PostPortNumber);
    List* commandList = NIL;

    /* tag the session with the rebalancer prefix + global pid */
    commandList = lappend(
        commandList, psprintf("SET LOCAL application_name TO '%s%ld'",
                              SPQ_REBALANCER_APPLICATION_NAME_PREFIX, GetGlobalPID()));

    if (Session_ctx::Vars().PropagateSessionSettingsForLoopbackConnection) {
        /* optionally carry session-level GUC settings over to the new connection */
        List* setCommands = GetSetCommandListForNewConnections();
        char* setCommand = NULL;

        foreach_declared_ptr(setCommand, setCommands)
        {
            commandList = lappend(commandList, setCommand);
        }
    }

    /* the actual command runs last, after application_name and GUC setup */
    commandList = lappend(commandList, command);

    SendCommandListToWorkerOutsideTransactionWithConnection(connection, commandList);
    CloseConnection(connection);
}

/*
 * GetSetCommandListForNewConnections returns a list of SET statements to
 * be executed in new connections to worker nodes.
 */
static List* GetSetCommandListForNewConnections(void)
{
    List* commandList = NIL;

    int gucCount = 0;
    struct config_generic** guc_vars = get_guc_variables_compat(&gucCount);

    for (int gucIndex = 0; gucIndex < gucCount; gucIndex++) {
        struct config_generic* var = (struct config_generic*)guc_vars[gucIndex];

        /* only propagate settings changed in this session that are safe to forward */
        if (var->source == PGC_S_SESSION && IsSettingSafeToPropagate(var->name)) {
            const char* variableValue = GetConfigOption(var->name, true, true);
            if (variableValue == NULL) {
                /* unreadable or unknown setting; nothing to propagate */
                continue;
            }

            /*
             * Quote the value with quote_literal_cstr so values containing
             * single quotes or backslashes cannot break (or inject into) the
             * generated SET command.
             */
            commandList = lappend(commandList,
                                  psprintf("SET LOCAL %s TO %s;", var->name,
                                           quote_literal_cstr(variableValue)));
        }
    }

    return commandList;
}

/*
 * RebalancePlacementUpdates returns a list of placement updates which makes the
 * cluster balanced. We move shards to these nodes until all nodes become utilized.
 * We consider a node under-utilized if it has less than floor((1.0 - threshold) *
 * placementCountAverage) shard placements. In each iteration we choose the node
 * with maximum number of shard placements as the source, and we choose the node
 * with minimum number of shard placements as the target. Then we choose a shard
 * which is placed in the source node but not in the target node as the shard to
 * move.
 *
 * The activeShardPlacementListList argument contains a list of lists of active shard
 * placements. Each of these lists are balanced independently. This is used to
 * make sure different colocation groups are balanced separately, so each list
 * contains the placements of a colocation group.
 */
List* RebalancePlacementUpdates(List* workerNodeList, List* activeShardPlacementListList,
                                double threshold, int32 maxShardMoves, bool drainOnly,
                                float4 improvementThreshold,
                                RebalancePlanFunctions* functions)
{
    List* rebalanceStates = NIL;
    RebalanceState* state = NULL;
    List* shardPlacementList = NIL;
    List* placementUpdateList = NIL;

    /* build one independent rebalance state per colocation group */
    foreach_declared_ptr(shardPlacementList, activeShardPlacementListList)
    {
        state = InitRebalanceState(workerNodeList, shardPlacementList, functions);
        rebalanceStates = lappend(rebalanceStates, state);
    }

    /* first pass: evacuate shards from nodes where they are not allowed */
    foreach_declared_ptr(state, rebalanceStates)
    {
        /* the shared update list is threaded through each state in turn */
        state->placementUpdateList = placementUpdateList;
        MoveShardsAwayFromDisallowedNodes(state);
        placementUpdateList = state->placementUpdateList;
    }

    if (!drainOnly) {
        foreach_declared_ptr(state, rebalanceStates)
        {
            state->placementUpdateList = placementUpdateList;

            /* calculate lower bound for placement count */
            float4 averageUtilization = (state->totalCost / state->totalCapacity);
            float4 utilizationLowerBound = ((1.0 - threshold) * averageUtilization);
            float4 utilizationUpperBound = ((1.0 + threshold) * averageUtilization);

            /* keep taking the best available move until none helps or the cap is hit */
            bool moreMovesAvailable = true;
            while (list_length(state->placementUpdateList) < maxShardMoves &&
                   moreMovesAvailable) {
                moreMovesAvailable =
                    FindAndMoveShardCost(utilizationLowerBound, utilizationUpperBound,
                                         improvementThreshold, state);
            }
            placementUpdateList = state->placementUpdateList;

            /* hitting maxShardMoves while moves remain aborts the whole plan early */
            if (moreMovesAvailable) {
                ereport(NOTICE, (errmsg("Stopped searching before we were out of moves. "
                                        "Please rerun the rebalancer after it's finished "
                                        "for a more optimal placement.")));
                break;
            }
        }
    }

    /* the per-state placement hashes are no longer needed */
    foreach_declared_ptr(state, rebalanceStates)
    {
        hash_destroy(state->placementsHash);
    }

    int64 ignoredMoves = 0;
    foreach_declared_ptr(state, rebalanceStates)
    {
        ignoredMoves += state->ignoredMoves;
    }

    /* NOTE(review): %ld assumes int64 == long (LP64); consider INT64_FORMAT */
    if (ignoredMoves > 0) {
        if (Session_ctx::Vars().MaxRebalancerLoggedIgnoredMoves == -1 ||
            ignoredMoves <= Session_ctx::Vars().MaxRebalancerLoggedIgnoredMoves) {
            ereport(
                NOTICE,
                (errmsg("Ignored %ld moves, all of which are shown in notices above",
                        ignoredMoves),
                 errhint(
                     "If you do want these moves to happen, try changing "
                     "improvement_threshold to a lower value than what it is now (%g).",
                     improvementThreshold)));
        } else {
            ereport(
                NOTICE,
                (errmsg("Ignored %ld moves, %d of which are shown in notices above",
                        ignoredMoves,
                        Session_ctx::Vars().MaxRebalancerLoggedIgnoredMoves),
                 errhint(
                     "If you do want these moves to happen, try changing "
                     "improvement_threshold to a lower value than what it is now (%g).",
                     improvementThreshold)));
        }
    }
    return placementUpdateList;
}

/*
 * InitRebalanceState sets up a RebalanceState for its arguments. The
 * RebalanceState contains the information needed to calculate shard moves.
 */
static RebalanceState* InitRebalanceState(List* workerNodeList, List* shardPlacementList,
                                          RebalancePlanFunctions* functions)
{
    ShardPlacement* placement = NULL;
    HASH_SEQ_STATUS status;
    WorkerNode* workerNode = NULL;

    RebalanceState* state = static_cast<RebalanceState*>(palloc0(sizeof(RebalanceState)));
    state->functions = functions;
    state->placementsHash = ShardPlacementsListToHash(shardPlacementList);

    /* create empty fill state for all of the worker nodes */
    foreach_declared_ptr(workerNode, workerNodeList)
    {
        NodeFillState* fillState =
            static_cast<NodeFillState*>(palloc0(sizeof(NodeFillState)));
        fillState->node = workerNode;
        fillState->capacity = functions->nodeCapacity(workerNode, functions->context);

        /*
         * Set the utilization here although the totalCost is not set yet. This
         * is needed to set the utilization to INFINITY when the capacity is 0.
         */
        fillState->utilization =
            CalculateUtilization(fillState->totalCost, fillState->capacity);
        /* both lists share the same fill state objects; only sort order differs */
        state->fillStateListAsc = lappend(state->fillStateListAsc, fillState);
        state->fillStateListDesc = lappend(state->fillStateListDesc, fillState);
        state->totalCapacity += fillState->capacity;
    }

    /* Fill the fill states for all of the worker nodes based on the placements */
    foreach_htab(placement, &status, state->placementsHash)
    {
        ShardCost* shardCost = static_cast<ShardCost*>(palloc0(sizeof(ShardCost)));
        NodeFillState* fillState = FindFillStateForPlacement(state, placement);

        Assert(fillState != NULL);

        *shardCost = functions->shardCost(placement->shardId, functions->context);

        fillState->totalCost += shardCost->cost;
        fillState->utilization =
            CalculateUtilization(fillState->totalCost, fillState->capacity);
        /* keep the per-node shard list ordered by descending cost */
        fillState->shardCostListDesc = lappend(fillState->shardCostListDesc, shardCost);
        fillState->shardCostListDesc =
            SortList(fillState->shardCostListDesc, CompareShardCostDesc);

        state->totalCost += shardCost->cost;

        /* remember placements that the strategy forbids on their current node */
        if (!functions->shardAllowedOnNode(placement->shardId, fillState->node,
                                           functions->context)) {
            DisallowedPlacement* disallowed =
                static_cast<DisallowedPlacement*>(palloc0(sizeof(DisallowedPlacement)));
            disallowed->shardCost = shardCost;
            disallowed->fillState = fillState;
            state->disallowedPlacementList =
                lappend(state->disallowedPlacementList, disallowed);
        }
    }
    foreach_htab_cleanup(placement, &status);

    state->fillStateListAsc = SortList(state->fillStateListAsc, CompareNodeFillStateAsc);
    state->fillStateListDesc =
        SortList(state->fillStateListDesc, CompareNodeFillStateDesc);
    CheckRebalanceStateInvariants(state);

    return state;
}

/*
 * CalculateUtilization returns INFINITY when capacity is 0 and
 * totalCost/capacity otherwise.
 */
static float4 CalculateUtilization(float4 totalCost, float4 capacity)
{
    /* a node with no (or negative) capacity counts as infinitely utilized */
    return (capacity <= 0) ? INFINITY : (totalCost / capacity);
}

/*
 * FindFillStateForPlacement finds the fillState for the workernode that
 * matches the placement.
 */
static NodeFillState* FindFillStateForPlacement(RebalanceState* state,
                                                ShardPlacement* placement)
{
    NodeFillState* candidate = NULL;

    /* scan the fill states until we hit the node hosting this placement */
    foreach_declared_ptr(candidate, state->fillStateListAsc)
    {
        if (IsPlacementOnWorkerNode(placement, candidate->node)) {
            return candidate;
        }
    }

    /* no fill state matches this placement's node */
    return NULL;
}

/*
 * CompareNodeFillStateAsc can be used to sort fill states from empty to full.
 */
static int CompareNodeFillStateAsc(const void* void1, const void* void2)
{
    const NodeFillState* a = *((const NodeFillState**)void1);
    const NodeFillState* b = *((const NodeFillState**)void2);
    if (a->utilization < b->utilization) {
        return -1;
    }
    if (a->utilization > b->utilization) {
        return 1;
    }

    /*
     * If utilization is equal, prefer nodes with more capacity, since
     * utilization will grow slower on those
     */
    if (a->capacity > b->capacity) {
        return -1;
    }
    if (a->capacity < b->capacity) {
        return 1;
    }

    /* Finally differentiate by node id */
    if (a->node->nodeId < b->node->nodeId) {
        return -1;
    }
    return a->node->nodeId > b->node->nodeId;
}

/*
 * CompareNodeFillStateDesc can be used to sort fill states from full to empty.
 */
static int CompareNodeFillStateDesc(const void* left, const void* right)
{
    /* descending order is simply the ascending comparison, inverted */
    return -CompareNodeFillStateAsc(left, right);
}

/*
 * CompareShardCostAsc can be used to sort shard costs from low cost to high
 * cost.
 */
static int CompareShardCostAsc(const void* void1, const void* void2)
{
    const ShardCost* a = *((const ShardCost**)void1);
    const ShardCost* b = *((const ShardCost**)void2);
    if (a->cost < b->cost) {
        return -1;
    }
    if (a->cost > b->cost) {
        return 1;
    }

    /* make compare function (more) stable for tests */
    if (a->shardId > b->shardId) {
        return -1;
    }
    return a->shardId < b->shardId;
}

/*
 * CompareShardCostDesc can be used to sort shard costs from high cost to low
 * cost.
 */
static int CompareShardCostDesc(const void* left, const void* right)
{
    /* descending order is simply the ascending comparison, inverted */
    return -CompareShardCostAsc(left, right);
}

/*
 * MoveShardsAwayFromDisallowedNodes appends placement updates to the state's
 * placementUpdateList that move any shards that are not allowed on their
 * current node to a node that they are allowed on.
 */
static void MoveShardsAwayFromDisallowedNodes(RebalanceState* state)
{
    DisallowedPlacement* disallowedPlacement = NULL;

    state->disallowedPlacementList =
        SortList(state->disallowedPlacementList, CompareDisallowedPlacementDesc);

    /* Move shards off of nodes they are not allowed on */
    foreach_declared_ptr(disallowedPlacement, state->disallowedPlacementList)
    {
        NodeFillState* targetFillState =
            FindAllowedTargetFillState(state, disallowedPlacement->shardCost->shardId);
        if (targetFillState == NULL) {
            ereport(WARNING, (errmsg("Not allowed to move shard " UINT64_FORMAT
                                     " anywhere from %s:%d",
                                     disallowedPlacement->shardCost->shardId,
                                     disallowedPlacement->fillState->node->workerName,
                                     disallowedPlacement->fillState->node->workerPort)));
            continue;
        }
        MoveShardCost(disallowedPlacement->fillState, targetFillState,
                      disallowedPlacement->shardCost, state);
    }
}

/*
 * CompareDisallowedPlacementAsc can be used to sort disallowed placements from
 * low cost to high cost.
 */
static int CompareDisallowedPlacementAsc(const void* void1, const void* void2)
{
    const DisallowedPlacement* a = *((const DisallowedPlacement**)void1);
    const DisallowedPlacement* b = *((const DisallowedPlacement**)void2);
    return CompareShardCostAsc(&(a->shardCost), &(b->shardCost));
}

/*
 * CompareDisallowedPlacementDesc can be used to sort disallowed placements from
 * high cost to low cost.
 */
static int CompareDisallowedPlacementDesc(const void* a, const void* b)
{
    /* reversing the arguments inverts the ascending order */
    return CompareDisallowedPlacementAsc(b, a);
}

/*
 * FindAllowedTargetFillState finds the first fill state in fillStateListAsc
 * where the shard can be moved to, i.e. the emptiest node that does not
 * already hold a placement of the shard and on which the rebalance strategy
 * allows the shard. Returns NULL when no such node exists.
 */
static NodeFillState* FindAllowedTargetFillState(RebalanceState* state, uint64 shardId)
{
    NodeFillState* targetFillState = NULL;
    foreach_declared_ptr(targetFillState, state->fillStateListAsc)
    {
        /* skip nodes that already hold a placement of this shard */
        bool hasShard =
            PlacementsHashFind(state->placementsHash, shardId, targetFillState->node);
        if (hasShard) {
            continue;
        }

        /* only consider nodes the rebalance strategy allows the shard on */
        if (state->functions->shardAllowedOnNode(shardId, targetFillState->node,
                                                 state->functions->context)) {
            return targetFillState;
        }
    }
    return NULL;
}

/*
 * MoveShardCost moves a shardcost from the source to the target fill states
 * and updates the RebalanceState accordingly. What it does in detail is:
 * 1. add a placement update to state->placementUpdateList
 * 2. update state->placementsHash
 * 3. update totalcost, utilization and shardCostListDesc in source and target
 * 4. resort state->fillStateListAsc/Desc
 */
static void MoveShardCost(NodeFillState* sourceFillState, NodeFillState* targetFillState,
                          ShardCost* shardCost, RebalanceState* state)
{
    uint64 movedShardId = shardCost->shardId;

    /* build and record the placement update describing this move */
    PlacementUpdateEvent* updateEvent =
        static_cast<PlacementUpdateEvent*>(palloc0(sizeof(PlacementUpdateEvent)));
    updateEvent->updateType = PLACEMENT_UPDATE_MOVE;
    updateEvent->shardId = movedShardId;
    updateEvent->sourceNode = sourceFillState->node;
    updateEvent->targetNode = targetFillState->node;
    state->placementUpdateList = lappend(state->placementUpdateList, updateEvent);

    /* reflect the move in the placements hash */
    PlacementsHashRemove(state->placementsHash, movedShardId, sourceFillState->node);
    PlacementsHashEnter(state->placementsHash, movedShardId, targetFillState->node);

    /* the source loses the shard's cost ... */
    sourceFillState->totalCost -= shardCost->cost;
    sourceFillState->utilization =
        CalculateUtilization(sourceFillState->totalCost, sourceFillState->capacity);
    sourceFillState->shardCostListDesc =
        list_delete_ptr(sourceFillState->shardCostListDesc, shardCost);

    /* ... and the target gains it, keeping its shard cost list sorted */
    targetFillState->totalCost += shardCost->cost;
    targetFillState->utilization =
        CalculateUtilization(targetFillState->totalCost, targetFillState->capacity);
    targetFillState->shardCostListDesc =
        SortList(lappend(targetFillState->shardCostListDesc, shardCost),
                 CompareShardCostDesc);

    /* utilizations changed, so both orderings of the fill states must be redone */
    state->fillStateListAsc = SortList(state->fillStateListAsc, CompareNodeFillStateAsc);
    state->fillStateListDesc =
        SortList(state->fillStateListDesc, CompareNodeFillStateDesc);
    CheckRebalanceStateInvariants(state);
}

/*
 * FindAndMoveShardCost is the main rebalancing algorithm. This takes the
 * current state and returns a list with a new move appended that improves the
 * balance of shards. The algorithm is greedy and will use the first new move
 * that improves the balance. It finds nodes by trying to move a shard from the
 * most utilized node (highest utilization) to the emptiest node (lowest
 * utilization). If no moves are possible it will try the second emptiest node
 * until it tried all of them. Then it will try the second fullest node. If it
 * was able to find a move it will return true and false if it couldn't.
 *
 * This algorithm won't necessarily result in the best possible balance. Getting
 * the best balance is an NP problem, so it's not feasible to go for the best
 * balance. This algorithm was chosen because of the following reasons:
 * 1. Literature research showed that similar problems would get within 2X of
 *    the optimal balance with a greedy algorithm.
 * 2. Every move will always improve the balance. So if the user stops a
 *    rebalance midway through, they will never be in a worse situation than
 *    before.
 * 3. It's pretty easy to reason about.
 * 4. It's simple to implement.
 *
 * utilizationLowerBound and utilizationUpperBound are used to indicate what
 * the target utilization range of all nodes is. If they are within this range,
 * then balance is good enough. If all nodes are in this range then the cluster
 * is considered balanced and no more moves are done. This is mostly useful for
 * the by_disk_size rebalance strategy. If we wouldn't have this then the
 * rebalancer could become flappy in certain cases.
 *
 * improvementThreshold is a threshold that can be used to ignore moves when
 * they only improve the balance a little relative to the cost of the shard.
 * Again this is mostly useful for the by_disk_size rebalance strategy.
 * Without this threshold the rebalancer would move a shard of 1TB when this
 * move only improves the cluster by 10GB.
 */
static bool FindAndMoveShardCost(float4 utilizationLowerBound,
                                 float4 utilizationUpperBound,
                                 float4 improvementThreshold, RebalanceState* state)
{
    NodeFillState* sourceFillState = NULL;
    NodeFillState* targetFillState = NULL;

    /*
     * find a source node for the move, starting at the node with the highest
     * utilization
     */
    foreach_declared_ptr(sourceFillState, state->fillStateListDesc)
    {
        /* Don't move shards away from nodes that are already too empty, we're
         * done searching */
        if (sourceFillState->utilization <= utilizationLowerBound) {
            return false;
        }

        /* find a target node for the move, starting at the node with the
         * lowest utilization */
        foreach_declared_ptr(targetFillState, state->fillStateListAsc)
        {
            ShardCost* shardCost = NULL;

            /* Don't add more shards to nodes that are already at the upper
             * bound. We should try the next source node now because further
             * target nodes will also be above the upper bound */
            if (targetFillState->utilization >= utilizationUpperBound) {
                break;
            }

            /* Don't move a shard between nodes that both have decent
             * utilization. We should try the next source node now because
             * further target nodes will also have decent utilization */
            if (targetFillState->utilization >= utilizationLowerBound &&
                sourceFillState->utilization <= utilizationUpperBound) {
                break;
            }

            /* find a shardcost that can be moved between nodes that
             * makes the cost distribution more equal */
            foreach_declared_ptr(shardCost, sourceFillState->shardCostListDesc)
            {
                /* Skip shards that already are on the node */
                bool targetHasShard = PlacementsHashFind(
                    state->placementsHash, shardCost->shardId, targetFillState->node);
                if (targetHasShard) {
                    continue;
                }

                /* Skip shards that are not allowed on the node */
                if (!state->functions->shardAllowedOnNode(shardCost->shardId,
                                                          targetFillState->node,
                                                          state->functions->context)) {
                    continue;
                }

                /* compute both nodes' utilization after a hypothetical move */
                float4 newTargetTotalCost = targetFillState->totalCost + shardCost->cost;
                float4 newTargetUtilization =
                    CalculateUtilization(newTargetTotalCost, targetFillState->capacity);
                float4 newSourceTotalCost = sourceFillState->totalCost - shardCost->cost;
                float4 newSourceUtilization =
                    CalculateUtilization(newSourceTotalCost, sourceFillState->capacity);

                /*
                 * If the target is still less utilized than the source, then
                 * this is clearly a good move. And if they are equally
                 * utilized too.
                 */
                if (newTargetUtilization <= newSourceUtilization) {
                    MoveShardCost(sourceFillState, targetFillState, shardCost, state);
                    return true;
                }

                /*
                 * The target is now more utilized than the source. So we need
                 * to determine if the move is a net positive for the overall
                 * cost distribution. This means that the new highest
                 * utilization of source and target is lower than the previous
                 * highest, or the highest utilization is the same, but the
                 * lowest increased.
                 */
                if (newTargetUtilization > sourceFillState->utilization) {
                    continue;
                }
                if (newTargetUtilization == sourceFillState->utilization &&
                    newSourceUtilization <=
                        targetFillState->utilization) /* lgtm[cpp/equality-on-floats] */
                {
                    /*
                     * this can trigger when capacity of the nodes is not the
                     * same. Example (also a test):
                     * - node with capacity 3
                     * - node with capacity 1
                     * - 3 shards with cost 1
                     * Best distribution would be 2 shards on node with
                     * capacity 3 and one on node with capacity 1
                     */
                    continue;
                }

                /*
                 * fmaxf and fminf here are only needed for cases when nodes
                 * have different capacities. If they are the same, then both
                 * arguments are equal.
                 */
                float4 utilizationImprovement =
                    fmaxf(sourceFillState->utilization - newTargetUtilization,
                          newSourceUtilization - targetFillState->utilization);
                float4 utilizationAddedByShard =
                    fminf(newTargetUtilization - targetFillState->utilization,
                          sourceFillState->utilization - newSourceUtilization);

                /*
                 * If the shard causes a lot of utilization, but the
                 * improvement which is gained by moving it is small, then we
                 * ignore the move. Probably there are other shards that are
                 * better candidates, and in any case it's probably not worth
                 * the effort to move the this shard.
                 *
                 * One of the main cases this tries to avoid is the rebalancer
                 * moving a very large shard with the "by_disk_size" strategy
                 * when that only gives a small benefit in data distribution.
                 */
                float4 normalizedUtilizationImprovement =
                    utilizationImprovement / utilizationAddedByShard;
                if (normalizedUtilizationImprovement < improvementThreshold) {
                    state->ignoredMoves++;
                    /* log only up to the configured cap; -1 means unlimited */
                    if (Session_ctx::Vars().MaxRebalancerLoggedIgnoredMoves == -1 ||
                        state->ignoredMoves <=
                            Session_ctx::Vars().MaxRebalancerLoggedIgnoredMoves) {
                        /*
                         * Use UINT64_FORMAT for the uint64 shard id instead of
                         * %ld, which is the wrong width on LLP64 platforms.
                         */
                        ereport(
                            NOTICE,
                            (errmsg("Ignoring move of shard " UINT64_FORMAT
                                    " from %s:%d to %s:%d, "
                                    "because the move only brings a small improvement "
                                    "relative to the shard its size",
                                    shardCost->shardId, sourceFillState->node->workerName,
                                    sourceFillState->node->workerPort,
                                    targetFillState->node->workerName,
                                    targetFillState->node->workerPort),
                             errdetail("The balance improvement of %g is lower than the "
                                       "improvement_threshold of %g",
                                       normalizedUtilizationImprovement,
                                       improvementThreshold)));
                    }
                    continue;
                }
                MoveShardCost(sourceFillState, targetFillState, shardCost, state);
                return true;
            }
        }
    }
    return false;
}

/*
 * ReplicationPlacementUpdates returns a list of placement updates which
 * replicates shard placements that need re-replication. To do this, the
 * function loops over the active shard placements, and for each shard placement
 * which needs to be re-replicated, it chooses an active worker node with
 * smallest number of shards as the target node.
 */
List* ReplicationPlacementUpdates(List* workerNodeList, List* activeShardPlacementList,
                                  int shardReplicationFactor)
{
    List* placementUpdateList = NIL;
    ListCell* shardPlacementCell = NULL;
    uint32 workerNodeIndex = 0;
    HTAB* placementsHash = ShardPlacementsListToHash(activeShardPlacementList);
    uint32 workerNodeCount = list_length(workerNodeList);

    /* get number of shards per node */
    uint32* shardCountArray =
        static_cast<uint32*>(palloc0(workerNodeCount * sizeof(uint32)));
    foreach (shardPlacementCell, activeShardPlacementList) {
        ShardPlacement* placement =
            static_cast<ShardPlacement*>(lfirst(shardPlacementCell));

        /* find the worker node index this placement lives on and count it */
        for (workerNodeIndex = 0; workerNodeIndex < workerNodeCount; workerNodeIndex++) {
            WorkerNode* node =
                static_cast<WorkerNode*>(list_nth(workerNodeList, workerNodeIndex));
            if (strncmp(node->workerName, placement->nodeName, WORKER_LENGTH) == 0 &&
                node->workerPort == placement->nodePort) {
                shardCountArray[workerNodeIndex]++;
                break;
            }
        }
    }

    foreach (shardPlacementCell, activeShardPlacementList) {
        WorkerNode* sourceNode = NULL;
        WorkerNode* targetNode = NULL;
        uint32 targetNodeShardCount = UINT_MAX;
        uint32 targetNodeIndex = 0;

        ShardPlacement* placement = (ShardPlacement*)lfirst(shardPlacementCell);
        uint64 shardId = placement->shardId;

        /* skip the shard placement if it has enough replications */
        int activePlacementCount =
            ShardActivePlacementCount(placementsHash, shardId, workerNodeList);
        if (activePlacementCount >= shardReplicationFactor) {
            continue;
        }

        /*
         * We can copy the shard from any active worker node that contains the
         * shard.
         */
        for (workerNodeIndex = 0; workerNodeIndex < workerNodeCount; workerNodeIndex++) {
            WorkerNode* workerNode =
                static_cast<WorkerNode*>(list_nth(workerNodeList, workerNodeIndex));

            bool placementExists =
                PlacementsHashFind(placementsHash, shardId, workerNode);
            if (placementExists) {
                sourceNode = workerNode;
                break;
            }
        }

        /*
         * If we couldn't find any worker node which contains the shard, then
         * all copies of the shard are lost and we should error out.
         */
        if (sourceNode == NULL) {
            ereport(ERROR, (errmsg("could not find a source for shard " UINT64_FORMAT,
                                   shardId)));
        }

        /*
         * We can copy the shard to any worker node that doesn't contain the shard.
         * Among such worker nodes, we choose the worker node with minimum shard
         * count as the target.
         */
        for (workerNodeIndex = 0; workerNodeIndex < workerNodeCount; workerNodeIndex++) {
            WorkerNode* workerNode =
                static_cast<WorkerNode*>(list_nth(workerNodeList, workerNodeIndex));

            if (!NodeCanHaveDistTablePlacements(workerNode)) {
                /* never replicate placements to nodes that should not have placements */
                continue;
            }

            /* skip this node if it already contains the shard */
            bool placementExists =
                PlacementsHashFind(placementsHash, shardId, workerNode);
            if (placementExists) {
                continue;
            }

            /* compare and change the target node */
            if (shardCountArray[workerNodeIndex] < targetNodeShardCount) {
                targetNode = workerNode;
                targetNodeShardCount = shardCountArray[workerNodeIndex];
                targetNodeIndex = workerNodeIndex;
            }
        }

        /*
         * If there is no worker node which doesn't contain the shard, then the
         * shard replication factor is greater than number of worker nodes, and
         * we should error out.
         */
        if (targetNode == NULL) {
            ereport(ERROR, (errmsg("could not find a target for shard " UINT64_FORMAT,
                                   shardId)));
        }

        /* construct the placement update */
        PlacementUpdateEvent* placementUpdateEvent =
            static_cast<PlacementUpdateEvent*>(palloc0(sizeof(PlacementUpdateEvent)));
        placementUpdateEvent->updateType = PLACEMENT_UPDATE_COPY;
        placementUpdateEvent->shardId = shardId;
        placementUpdateEvent->sourceNode = sourceNode;
        placementUpdateEvent->targetNode = targetNode;

        /* record the placement update */
        placementUpdateList = lappend(placementUpdateList, placementUpdateEvent);

        /* update the placements hash and the shard count array */
        PlacementsHashEnter(placementsHash, shardId, targetNode);
        shardCountArray[targetNodeIndex]++;
    }

    hash_destroy(placementsHash);

    return placementUpdateList;
}

/*
 * ShardActivePlacementCount returns the number of active placements for the
 * given shard which are placed at the active worker nodes.
 */
static int ShardActivePlacementCount(HTAB* activePlacementsHash, uint64 shardId,
                                     List* activeWorkerNodeList)
{
    int activeCount = 0;
    ListCell* nodeCell = NULL;

    /* count the active worker nodes on which a placement of the shard exists */
    foreach (nodeCell, activeWorkerNodeList) {
        WorkerNode* candidateNode = static_cast<WorkerNode*>(lfirst(nodeCell));
        if (PlacementsHashFind(activePlacementsHash, shardId, candidateNode)) {
            activeCount++;
        }
    }

    return activeCount;
}

/*
 * ShardPlacementsListToHash creates and returns a hash set from a shard
 * placement list.
 */
static HTAB* ShardPlacementsListToHash(List* shardPlacementList)
{
    HASHCTL hashInfo;
    memset(&hashInfo, 0, sizeof(hashInfo));
    hashInfo.keysize = sizeof(ShardPlacement);
    hashInfo.entrysize = sizeof(ShardPlacement);
    hashInfo.hash = PlacementsHashHashCode;
    hashInfo.match = PlacementsHashCompare;
    hashInfo.hcxt = CurrentMemoryContext;

    /* size the hash by the number of placements we are about to insert */
    HTAB* placementsHash =
        hash_create("ActivePlacements Hash", list_length(shardPlacementList), &hashInfo,
                    HASH_ELEM | HASH_FUNCTION | HASH_COMPARE | HASH_CONTEXT);

    ListCell* placementCell = NULL;
    foreach (placementCell, shardPlacementList) {
        ShardPlacement* placement = (ShardPlacement*)lfirst(placementCell);
        hash_search(placementsHash, (void*)placement, HASH_ENTER, NULL);
    }

    return placementsHash;
}

/*
 * PlacementsHashFind returns true if there exists a shard placement with the
 * given workerNode and shard id in the given placements hash, otherwise it
 * returns false.
 */
static bool PlacementsHashFind(HTAB* placementsHash, uint64 shardId,
                               WorkerNode* workerNode)
{
    /* build the lookup key from the shard id and the worker's name/port */
    ShardPlacement searchKey;
    memset(&searchKey, 0, sizeof(searchKey));
    searchKey.shardId = shardId;
    searchKey.nodeName = workerNode->workerName;
    searchKey.nodePort = workerNode->workerPort;

    bool placementFound = false;
    hash_search(placementsHash, (void*)(&searchKey), HASH_FIND, &placementFound);

    return placementFound;
}

/*
 * PlacementsHashEnter enters a shard placement for the given worker node and
 * shard id to the given placements hash.
 */
static void PlacementsHashEnter(HTAB* placementsHash, uint64 shardId,
                                WorkerNode* workerNode)
{
    /* build the entry key from the shard id and the worker's name/port */
    ShardPlacement entryKey;
    memset(&entryKey, 0, sizeof(entryKey));
    entryKey.shardId = shardId;
    entryKey.nodeName = workerNode->workerName;
    entryKey.nodePort = workerNode->workerPort;

    hash_search(placementsHash, (void*)(&entryKey), HASH_ENTER, NULL);
}

/*
 * PlacementsHashRemove removes the shard placement for the given worker node and
 * shard id from the given placements hash.
 */
static void PlacementsHashRemove(HTAB* placementsHash, uint64 shardId,
                                 WorkerNode* workerNode)
{
    /* build the removal key from the shard id and the worker's name/port */
    ShardPlacement removalKey;
    memset(&removalKey, 0, sizeof(removalKey));
    removalKey.shardId = shardId;
    removalKey.nodeName = workerNode->workerName;
    removalKey.nodePort = workerNode->workerPort;

    hash_search(placementsHash, (void*)(&removalKey), HASH_REMOVE, NULL);
}

/*
 * PlacementsHashCompare compares two shard placements using shard id, node name,
 * and node port number. Returns a negative, zero, or positive value following
 * the usual comparator convention.
 */
static int PlacementsHashCompare(const void* lhsKey, const void* rhsKey, Size keySize)
{
    const ShardPlacement* placementLhs = (const ShardPlacement*)lhsKey;
    const ShardPlacement* placementRhs = (const ShardPlacement*)rhsKey;

    /* first, compare by shard id */
    if (placementLhs->shardId < placementRhs->shardId) {
        return -1;
    }
    if (placementLhs->shardId > placementRhs->shardId) {
        return 1;
    }

    /* then, compare by node name */
    int nodeNameCompare =
        strncmp(placementLhs->nodeName, placementRhs->nodeName, WORKER_LENGTH);
    if (nodeNameCompare != 0) {
        return nodeNameCompare;
    }

    /*
     * Finally, compare by node port. Compare explicitly instead of
     * subtracting: the ports are unsigned, so a plain subtraction wraps
     * around and relies on implementation-defined conversion to int for its
     * sign.
     */
    if (placementLhs->nodePort < placementRhs->nodePort) {
        return -1;
    }
    if (placementLhs->nodePort > placementRhs->nodePort) {
        return 1;
    }
    return 0;
}

/*
 * PlacementsHashHashCode computes the hash code for a shard placement from the
 * placement's shard id, node name, and node port number.
 */
static uint32 PlacementsHashHashCode(const void* key, Size keySize)
{
    const ShardPlacement* placement = (const ShardPlacement*)key;

    /* standard hash function outlined in Effective Java, Item 8 */
    uint32 hashValue = 17;
    hashValue = 37 * hashValue + tag_hash(&(placement->shardId), sizeof(uint64));
    hashValue = 37 * hashValue + string_hash(placement->nodeName, WORKER_LENGTH);
    hashValue = 37 * hashValue + tag_hash(&(placement->nodePort), sizeof(uint32));

    return hashValue;
}

/* WorkerNodeListContains checks if the worker node exists in the given list. */
static bool WorkerNodeListContains(List* workerNodeList, const char* workerName,
                                   uint32 workerPort)
{
    ListCell* workerNodeCell = NULL;

    foreach (workerNodeCell, workerNodeList) {
        WorkerNode* candidateNode = (WorkerNode*)lfirst(workerNodeCell);

        /* a node matches when both its name and its port match */
        if (candidateNode->workerPort == workerPort &&
            strncmp(candidateNode->workerName, workerName, WORKER_LENGTH) == 0) {
            return true;
        }
    }

    return false;
}

/*
 * UpdateColocatedShardPlacementProgress updates the progress of the given placement,
 * along with its colocated placements, to the given state. The progress value is
 * written atomically into every progress-monitor step whose shard is colocated
 * with shardId and whose source matches sourceName:sourcePort. Does nothing when
 * no progress monitor is active.
 */
static void UpdateColocatedShardPlacementProgress(uint64 shardId, char* sourceName,
                                                  int sourcePort, uint64 progress)
{
    ProgressMonitorData* header = GetCurrentProgressMonitor();

    if (header != NULL) {
        PlacementUpdateEventProgress* steps =
            static_cast<PlacementUpdateEventProgress*>(ProgressMonitorSteps(header));
        ListCell* colocatedShardIntervalCell = NULL;

        ShardInterval* shardInterval = LoadShardInterval(shardId);
        List* colocatedShardIntervalList = ColocatedShardIntervalList(shardInterval);

        /* scan every step of the monitor and update the matching ones */
        for (int moveIndex = 0; moveIndex < header->stepCount; moveIndex++) {
            PlacementUpdateEventProgress* step = steps + moveIndex;
            uint64 currentShardId = step->shardId;
            bool colocatedShard = false;

            /* is this step's shard part of the colocation group of shardId? */
            foreach (colocatedShardIntervalCell, colocatedShardIntervalList) {
                ShardInterval* candidateShard =
                    static_cast<ShardInterval*>(lfirst(colocatedShardIntervalCell));
                if (candidateShard->shardId == currentShardId) {
                    colocatedShard = true;
                    break;
                }
            }

            /* only update steps originating from the given source node */
            if (colocatedShard && strcmp(step->sourceName, sourceName) == 0 &&
                step->sourcePort == sourcePort) {
                /* atomic write: concurrent readers may be polling the monitor */
                pg_atomic_write_u64(&step->progress, progress);
            }
        }
    }
}

/*
 * pg_dist_rebalance_strategy_enterprise_check is a now removed function, but
 * to avoid issues during upgrades a C stub is kept.
 */
Datum pg_dist_rebalance_strategy_enterprise_check(PG_FUNCTION_ARGS)
{
    /* intentionally a no-op: kept only so old SQL definitions keep resolving */
    PG_RETURN_VOID();
}

/*
 * citus_validate_rebalance_strategy_functions checks all the functions for
 * their correct signature.
 *
 * SQL signature:
 *
 * citus_validate_rebalance_strategy_functions(
 *     shard_cost_function regproc,
 *     node_capacity_function regproc,
 *     shard_allowed_on_node_function regproc,
 * ) RETURNS VOID
 */
Datum citus_validate_rebalance_strategy_functions(PG_FUNCTION_ARGS)
{
    CheckCitusVersion(ERROR);

    Oid shardCostFunctionOid = PG_GETARG_OID(0);
    Oid nodeCapacityFunctionOid = PG_GETARG_OID(1);
    Oid shardAllowedOnNodeFunctionOid = PG_GETARG_OID(2);

    /* each check errors out on a signature mismatch */
    EnsureShardCostUDF(shardCostFunctionOid);
    EnsureNodeCapacityUDF(nodeCapacityFunctionOid);
    EnsureShardAllowedOnNodeUDF(shardAllowedOnNodeFunctionOid);

    PG_RETURN_VOID();
}

/*
 * EnsureShardCostUDF checks that the UDF matching the oid has the correct
 * signature to be used as a ShardCost function. The expected signature is:
 *
 * shard_cost(shardid bigint) returns float4
 */
static void EnsureShardCostUDF(Oid functionOid)
{
    HeapTuple procTuple = SearchSysCache1(PROCOID, ObjectIdGetDatum(functionOid));
    if (!HeapTupleIsValid(procTuple)) {
        ereport(ERROR, (errmsg("cache lookup failed for shard_cost_function with oid %u",
                               functionOid)));
    }

    Form_pg_proc procStruct = (Form_pg_proc)GETSTRUCT(procTuple);
    char* functionName = NameStr(procStruct->proname);

    /* the function must take exactly one bigint and return a real */
    if (procStruct->pronargs != 1) {
        ereport(ERROR, (errmsg("signature for shard_cost_function is incorrect"),
                        errdetail("number of arguments of %s should be 1, not %i",
                                  functionName, procStruct->pronargs)));
    }
    if (procStruct->proargtypes.values[0] != INT8OID) {
        ereport(ERROR, (errmsg("signature for shard_cost_function is incorrect"),
                        errdetail("argument type of %s should be bigint", functionName)));
    }
    if (procStruct->prorettype != FLOAT4OID) {
        ereport(ERROR, (errmsg("signature for shard_cost_function is incorrect"),
                        errdetail("return type of %s should be real", functionName)));
    }

    ReleaseSysCache(procTuple);
}

/*
 * EnsureNodeCapacityUDF checks that the UDF matching the oid has the correct
 * signature to be used as a NodeCapacity function. The expected signature is:
 *
 * node_capacity(nodeid int) returns float4
 */
static void EnsureNodeCapacityUDF(Oid functionOid)
{
    HeapTuple procTuple = SearchSysCache1(PROCOID, ObjectIdGetDatum(functionOid));
    if (!HeapTupleIsValid(procTuple)) {
        ereport(ERROR,
                (errmsg("cache lookup failed for node_capacity_function with oid %u",
                        functionOid)));
    }

    Form_pg_proc procStruct = (Form_pg_proc)GETSTRUCT(procTuple);
    char* functionName = NameStr(procStruct->proname);

    /* the function must take exactly one int and return a real */
    if (procStruct->pronargs != 1) {
        ereport(ERROR, (errmsg("signature for node_capacity_function is incorrect"),
                        errdetail("number of arguments of %s should be 1, not %i",
                                  functionName, procStruct->pronargs)));
    }
    if (procStruct->proargtypes.values[0] != INT4OID) {
        ereport(ERROR, (errmsg("signature for node_capacity_function is incorrect"),
                        errdetail("argument type of %s should be int", functionName)));
    }
    if (procStruct->prorettype != FLOAT4OID) {
        ereport(ERROR, (errmsg("signature for node_capacity_function is incorrect"),
                        errdetail("return type of %s should be real", functionName)));
    }

    ReleaseSysCache(procTuple);
}

/*
 * EnsureShardAllowedOnNodeUDF checks that the UDF matching the oid has the correct
 * signature to be used as a ShardAllowedOnNode function. The expected signature is:
 *
 * shard_allowed_on_node(shardid bigint, nodeid int) returns boolean
 */
static void EnsureShardAllowedOnNodeUDF(Oid functionOid)
{
    HeapTuple procTuple = SearchSysCache1(PROCOID, ObjectIdGetDatum(functionOid));
    if (!HeapTupleIsValid(procTuple)) {
        ereport(
            ERROR,
            (errmsg("cache lookup failed for shard_allowed_on_node_function with oid %u",
                    functionOid)));
    }

    Form_pg_proc procStruct = (Form_pg_proc)GETSTRUCT(procTuple);
    char* functionName = NameStr(procStruct->proname);

    /* the function must take a bigint and an int and return a boolean */
    if (procStruct->pronargs != 2) {
        ereport(ERROR,
                (errmsg("signature for shard_allowed_on_node_function is incorrect"),
                 errdetail("number of arguments of %s should be 2, not %i", functionName,
                           procStruct->pronargs)));
    }
    if (procStruct->proargtypes.values[0] != INT8OID) {
        ereport(ERROR,
                (errmsg("signature for shard_allowed_on_node_function is incorrect"),
                 errdetail("type of first argument of %s should be bigint",
                           functionName)));
    }
    if (procStruct->proargtypes.values[1] != INT4OID) {
        ereport(ERROR,
                (errmsg("signature for shard_allowed_on_node_function is incorrect"),
                 errdetail("type of second argument of %s should be int", functionName)));
    }
    if (procStruct->prorettype != BOOLOID) {
        ereport(ERROR,
                (errmsg("signature for shard_allowed_on_node_function is incorrect"),
                 errdetail("return type of %s should be boolean", functionName)));
    }

    ReleaseSysCache(procTuple);
}

/*
 * GetRemoteLSN executes a command that returns a single LSN over the given connection
 * and returns it as an XLogRecPtr (uint64).
 *
 * Returns InvalidXLogRecPtr when the query yields an unexpected number of rows
 * or a NULL value; errors out on connection/query failures or a wrong column
 * count.
 */
static XLogRecPtr GetRemoteLSN(MultiConnection* connection)
{
    const char* lsnCommand = "SELECT pg_current_wal_lsn()";
    XLogRecPtr remoteLsn = InvalidXLogRecPtr;

    if (SendRemoteCommand(connection, lsnCommand) == 0) {
        ReportConnectionError(connection, ERROR);
    }

    bool raiseInterrupts = false;
    PGresult* queryResult = GetRemoteCommandResult(connection, raiseInterrupts);
    if (!IsResponseOK(queryResult)) {
        ReportResultError(connection, queryResult, ERROR);
    }

    /* anything but a single row means we cannot extract a meaningful LSN */
    if (PQntuples(queryResult) != 1) {
        PQclear(queryResult);
        ForgetResults(connection);
        return InvalidXLogRecPtr;
    }

    if (PQnfields(queryResult) != 1) {
        ereport(ERROR, (errmsg("unexpected number of columns returned by: %s",
                               lsnCommand)));
    }

    /* a NULL cell leaves the result as InvalidXLogRecPtr */
    if (!PQgetisnull(queryResult, 0, 0)) {
        char* lsnString = PQgetvalue(queryResult, 0, 0);
        Datum lsnDatum = DirectFunctionCall1Coll(pg_lsn_in, InvalidOid,
                                                 CStringGetDatum(lsnString));
        remoteLsn = DatumGetLSN(lsnDatum);
    }

    PQclear(queryResult);
    ForgetResults(connection);

    return remoteLsn;
}